Skip to content

Commit

Permalink
Camel to snake replicatorOrSkip.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 256700843
Change-Id: I7648c0cf0192eaae07b9d4955fc121de727ff9f5
  • Loading branch information
tomhennigan authored and sonnet-copybara committed Jul 5, 2019
1 parent 0c1aa77 commit ccba0f9
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions sonnet/src/conformance/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_restore_golden(self, golden):

class ReplicatorCheckpointTest(test_utils.TestCase, parameterized.TestCase):

def replicatorOrSkip(self, replicator_fn, use_function):
def replicator_or_skip(self, replicator_fn, use_function):
replicator = replicator_fn()
if not use_function and isinstance(replicator,
snt_replicator.TpuReplicator):
Expand All @@ -189,7 +189,7 @@ def replicatorOrSkip(self, replicator_fn, use_function):
replicator_utils.named_replicators(),
test_utils.named_bools("use_function"))
def test_save_restore(self, golden, replicator_fn, use_function):
replicator = self.replicatorOrSkip(replicator_fn, use_function)
replicator = self.replicator_or_skip(replicator_fn, use_function)

with replicator.scope():
module = golden.create_module()
Expand Down Expand Up @@ -246,7 +246,7 @@ def forward():
@test_utils.combined_named_parameters(goldens.named_goldens(),
replicator_utils.named_replicators())
def test_restore_from_golden(self, golden, replicator_fn):
replicator = self.replicatorOrSkip(replicator_fn, use_function=False)
replicator = self.replicator_or_skip(replicator_fn, use_function=False)

with replicator.scope():
module = golden.create_module()
Expand All @@ -262,7 +262,8 @@ def test_restore_from_golden(self, golden, replicator_fn):
test_utils.named_bools("use_function"))
def test_restore_from_non_distributed(self, golden, replicator_fn,
use_function):
replicator = self.replicatorOrSkip(replicator_fn, use_function)
replicator = self.replicator_or_skip(replicator_fn, use_function)

# Save a checkpoint from a non-distributed model.
module = golden.create_module()
normal_variables = golden.create_all_variables(module)
Expand Down Expand Up @@ -310,7 +311,7 @@ def run_forward(module):
@test_utils.combined_named_parameters(goldens.named_goldens(),
replicator_utils.named_replicators())
def test_restore_on_create(self, golden, replicator_fn):
replicator = self.replicatorOrSkip(replicator_fn, use_function=False)
replicator = self.replicator_or_skip(replicator_fn, use_function=False)

# Save a checkpoint from a non-distributed model.
module = golden.create_module()
Expand Down Expand Up @@ -343,7 +344,7 @@ def test_restore_on_create_in_replica_context(self, golden, replicator_fn,
self.skipTest("Currently not working as expected on multiple devices")
# TODO(b/134376796) renable this once bug is fixed

replicator = self.replicatorOrSkip(replicator_fn, use_function)
replicator = self.replicator_or_skip(replicator_fn, use_function)

# Save a checkpoint from a non-distributed model.
module = golden.create_module()
Expand Down

0 comments on commit ccba0f9

Please sign in to comment.