Skip to content

Commit

Permalink
Use legacy keras optimizer to be compatible with an incoming Keras op…
Browse files Browse the repository at this point in the history
…timizer migration.

PiperOrigin-RevId: 467992914
  • Loading branch information
chenmoneygithub authored and Copybara-Service committed Aug 16, 2022
1 parent d1cd371 commit c92ac8e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
5 changes: 4 additions & 1 deletion sonnet/src/optimizers/momentum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,10 @@ def make_optimizer(self, **kwargs):
if "use_nesterov" in kwargs:
kwargs["nesterov"] = kwargs["use_nesterov"]
del kwargs["use_nesterov"]
return optimizer_tests.WrappedTFOptimizer(tf.optimizers.SGD(**kwargs))
if hasattr(tf.keras.optimizers, "legacy"):
return optimizer_tests.WrappedTFOptimizer(
tf.keras.optimizers.legacy.SGD(**kwargs))
return optimizer_tests.WrappedTFOptimizer(tf.keras.optimizers.SGD(**kwargs))


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion sonnet/src/optimizers/rmsprop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,11 @@ def make_optimizer(self, **kwargs):
if "learning_rate" not in kwargs:
kwargs["learning_rate"] = 0.1
kwargs["rho"] = kwargs.pop("decay", 0.9)
return optimizer_tests.WrappedTFOptimizer(tf.optimizers.RMSprop(**kwargs))
if hasattr(tf.keras.optimizers, "legacy"):
return optimizer_tests.WrappedTFOptimizer(
tf.keras.optimizers.legacy.RMSprop(**kwargs))
return optimizer_tests.WrappedTFOptimizer(
tf.keras.optimizers.RMSprop(**kwargs))


if __name__ == "__main__":
Expand Down
5 changes: 4 additions & 1 deletion sonnet/src/optimizers/sgd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ class ReferenceSGDTest(SGDTest):
def make_optimizer(self, *args, **kwargs):
if "learning_rate" not in kwargs:
kwargs["learning_rate"] = 3.
return optimizer_tests.WrappedTFOptimizer(tf.optimizers.SGD(**kwargs))
if hasattr(tf.keras.optimizers, "legacy"):
return optimizer_tests.WrappedTFOptimizer(
tf.keras.optimizers.legacy.SGD(**kwargs))
return optimizer_tests.WrappedTFOptimizer(tf.keras.optimizers.SGD(**kwargs))


if __name__ == "__main__":
Expand Down

0 comments on commit c92ac8e

Please sign in to comment.