
Commit 0bfca89
Check for a compatible distribution strategy on every call to apply.
PiperOrigin-RevId: 258574659
Change-Id: If4f255932e79cc71b559472fdb8792c488d433b6
petebu authored and sonnet-copybara committed Jul 17, 2019
1 parent bc63faf commit 0bfca89
Showing 5 changed files with 12 additions and 12 deletions.
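
Why the call site moves: `_initialize` is wrapped in `@once.once`, so the old `check_strategy` call ran only on the first `apply`. Checking at the top of every `apply` also catches a strategy that is introduced (or swapped) between calls. Below is a minimal sketch of that failure mode. It is not part of this commit and assumes the Sonnet 2 optimizer API (`snt.optimizers.SGD`), with `MirroredStrategy` as an example of a strategy outside `_SUPPORTED_STRATEGIES`:

# Sketch (not from this commit): why the strategy check belongs in apply().
import sonnet as snt
import tensorflow as tf

opt = snt.optimizers.SGD(learning_rate=0.1)
w = tf.Variable([1.0, 2.0])

# First call runs outside any strategy scope. A check performed only at
# initialization time (the old @once.once behaviour) passes here and is
# never evaluated again.
opt.apply([tf.constant([0.1, 0.1])], [w])

# A later call under an unsupported strategy. With the check moved into
# apply(), this raises ValueError instead of proceeding silently.
with tf.distribute.MirroredStrategy().scope():
    opt.apply([tf.constant([0.1, 0.1])], [w])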

sonnet/src/adam.py (2 additions & 2 deletions)

@@ -68,7 +68,6 @@ def __init__(self,

  @once.once
  def _initialize(self, parameters):
-    optimizer_utils.check_strategy()
    zero_var = lambda p: utils.variable_like(p, trainable=False)
    with tf.name_scope("m"):
      self.m.extend(zero_var(p) for p in parameters)
@@ -99,6 +98,7 @@ def apply(self, updates, parameters):
      ValueError: If `updates` and `parameters` are empty, have different
        lengths, or have inconsistent types.
    """
+    optimizer_utils.check_distribution_strategy()
    optimizer_utils.check_updates_parameters(updates, parameters)
    self._initialize(parameters)
    self.step.assign_add(1)
@@ -153,7 +153,6 @@ def __init__(self,

  @once.once
  def _initialize(self, parameters):
-    optimizer_utils.check_strategy()
    zero_var = lambda p: utils.variable_like(p, trainable=False)
    with tf.name_scope("m"):
      self.m.extend(zero_var(p) for p in parameters)
@@ -162,6 +161,7 @@ def _initialize(self, parameters):

  def apply(self, updates, parameters):
    """Applies updates to parameters."""
+    optimizer_utils.check_distribution_strategy()
    optimizer_utils.check_updates_parameters(updates, parameters)
    self._initialize(parameters)
    self.step.assign_add(1)

sonnet/src/momentum.py (2 additions & 2 deletions)

@@ -53,7 +53,6 @@ def __init__(self, learning_rate, momentum, use_nesterov=False, name=None):

  @once.once
  def _initialize(self, parameters):
-    optimizer_utils.check_strategy()
    with tf.name_scope("accumulated_momentum"):
      self.accumulated_momentum.extend(
          utils.variable_like(p, trainable=False) for p in parameters)
@@ -83,6 +82,7 @@ def apply(self, updates, parameters):
      ValueError: If `updates` and `parameters` are empty, have different
        lengths, or have inconsistent types.
    """
+    optimizer_utils.check_distribution_strategy()
    optimizer_utils.check_updates_parameters(updates, parameters)
    self._initialize(parameters)
    for update, parameter, momentum in zip(
@@ -124,13 +124,13 @@ def __init__(self, learning_rate, momentum, use_nesterov=False, name=None):

  @once.once
  def _initialize(self, parameters):
-    optimizer_utils.check_strategy()
    with tf.name_scope("accumulated_momentum"):
      self.accumulated_momentum.extend(
          utils.variable_like(p, trainable=False) for p in parameters)

  def apply(self, updates, parameters):
    """Applies updates to parameters."""
+    optimizer_utils.check_distribution_strategy()
    optimizer_utils.check_updates_parameters(updates, parameters)
    self._initialize(parameters)
    for update, parameter, accumulated_momentum in zip(

sonnet/src/optimizer_utils.py (4 additions & 4 deletions)

@@ -26,22 +26,22 @@
# a simplified update model and replica local variables.
# TODO(cjfj,petebu,tomhennigan) Add async parameter server strategy when needed.
# TODO(cjfj,petebu,tomhennigan) Add sync multi-worker GPU strategy when needed.
-SUPPORTED_STRATEGIES = (
+_SUPPORTED_STRATEGIES = (
    tf.distribute.OneDeviceStrategy,
    replicator.Replicator,
    replicator.TpuReplicator,
)


-def check_strategy():
+def check_distribution_strategy():
  if tf.distribute.has_strategy():
    strategy = tf.distribute.get_strategy()
-    if not isinstance(strategy, SUPPORTED_STRATEGIES):
+    if not isinstance(strategy, _SUPPORTED_STRATEGIES):
      raise ValueError(
          "Sonnet optimizers are not compatible with `{}`. "
          "Please use one of `{}` instead.".format(
              strategy.__class__.__name__,
-              "`, `".join(s.__name__ for s in SUPPORTED_STRATEGIES)))
+              "`, `".join(s.__name__ for s in _SUPPORTED_STRATEGIES)))


def check_updates_parameters(updates, parameters):
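
Since `check_distribution_strategy` is the gate every optimizer now hits on each `apply`, it can help to see it exercised in isolation. Below is a self-contained sketch of the same logic, with Sonnet's internal `replicator.Replicator`/`replicator.TpuReplicator` entries omitted from the supported tuple (an assumption made purely so the example runs without Sonnet internals):

# Standalone sketch of the check; the supported tuple is trimmed to the
# plain TensorFlow strategy so the example needs no Sonnet-internal imports.
import tensorflow as tf

_SUPPORTED_STRATEGIES = (tf.distribute.OneDeviceStrategy,)

def check_distribution_strategy():
    if tf.distribute.has_strategy():
        strategy = tf.distribute.get_strategy()
        if not isinstance(strategy, _SUPPORTED_STRATEGIES):
            raise ValueError(
                "Sonnet optimizers are not compatible with `{}`. "
                "Please use one of `{}` instead.".format(
                    strategy.__class__.__name__,
                    "`, `".join(s.__name__ for s in _SUPPORTED_STRATEGIES)))

check_distribution_strategy()  # no strategy active: passes silently

with tf.distribute.OneDeviceStrategy("/cpu:0").scope():
    check_distribution_strategy()  # supported strategy: passes

with tf.distribute.MirroredStrategy().scope():
    check_distribution_strategy()  # unsupported strategy: raises ValueError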

sonnet/src/rmsprop.py (2 additions & 2 deletions)

@@ -87,7 +87,6 @@ def __init__(self, learning_rate, decay=0.9, momentum=0.0, epsilon=1e-10,

  @once.once
  def _initialize(self, parameters):
-    optimizer_utils.check_strategy()
    zero_var = lambda p: utils.variable_like(p, trainable=False)
    with tf.name_scope("momentum"):
      self.mom.extend(zero_var(p) for p in parameters)
@@ -110,6 +109,7 @@ def apply(self, updates, parameters):
      ValueError: If `updates` and `parameters` are empty, have different
        lengths, or have inconsistent types.
    """
+    optimizer_utils.check_distribution_strategy()
    optimizer_utils.check_updates_parameters(updates, parameters)
    self._initialize(parameters)
    for update, parameter, mom, ms, mg in six.moves.zip_longest(
@@ -167,7 +167,6 @@ def __init__(self, learning_rate, decay=0.9, momentum=0.0, epsilon=1e-10,

  @once.once
  def _initialize(self, parameters):
-    optimizer_utils.check_strategy()
    zero_var = lambda p: utils.variable_like(p, trainable=False)
    with tf.name_scope("momentum"):
      self.mom.extend(zero_var(p) for p in parameters)
@@ -179,6 +178,7 @@ def _initialize(self, parameters):

  def apply(self, updates, parameters):
    """Applies updates to parameters."""
+    optimizer_utils.check_distribution_strategy()
    optimizer_utils.check_updates_parameters(updates, parameters)
    self._initialize(parameters)
    for update, parameter, mom, ms, mg in six.moves.zip_longest(

sonnet/src/sgd.py (2 additions & 2 deletions)

@@ -54,8 +54,8 @@ def apply(self, updates, parameters):
      ValueError: If `updates` and `parameters` are empty, have different
        lengths, or have inconsistent types.
    """
+    optimizer_utils.check_distribution_strategy()
    optimizer_utils.check_updates_parameters(updates, parameters)
-    optimizer_utils.check_strategy()
    for update, parameter in zip(updates, parameters):
      if update is not None:
        optimizer_utils.check_same_dtype(update, parameter)
@@ -77,8 +77,8 @@ def __init__(self, learning_rate, name=None):

  def apply(self, updates, parameters):
    """Applies updates to parameters."""
+    optimizer_utils.check_distribution_strategy()
    optimizer_utils.check_updates_parameters(updates, parameters)
-    optimizer_utils.check_strategy()
    for update, parameter in zip(updates, parameters):
      if update is not None:
        optimizer_utils.check_same_dtype(update, parameter)
