Skip to content

Commit

Permalink
Fix Keras imports for optimizer algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
master committed Feb 15, 2022
1 parent 90c432a commit 4814b0c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tensorflow_riemopt/optimizers/constrained_rmsprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import gen_training_ops
from keras.optimizer_v2.optimizer_v2 import OptimizerV2

from tensorflow_riemopt.variable import get_manifold


@generic_utils.register_keras_serializable(name="ConstrainedRMSprop")
class ConstrainedRMSprop(optimizer_v2.OptimizerV2):
class ConstrainedRMSprop(OptimizerV2):
"""Optimizer that implements the RMSprop algorithm."""

_HAS_AGGREGATE_GRAD = True
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_riemopt/optimizers/riemannian_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import gen_training_ops
from keras.optimizer_v2.optimizer_v2 import OptimizerV2

from tensorflow_riemopt.variable import get_manifold


@generic_utils.register_keras_serializable(name="RiemannianAdam")
class RiemannianAdam(optimizer_v2.OptimizerV2):
class RiemannianAdam(OptimizerV2):
"""Optimizer that implements the Riemannian Adam algorithm."""

_HAS_AGGREGATE_GRAD = True
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_riemopt/optimizers/riemannian_gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import gen_training_ops
from keras.optimizer_v2.optimizer_v2 import OptimizerV2

from tensorflow_riemopt.variable import get_manifold


@generic_utils.register_keras_serializable(name="RiemannianSGD")
class RiemannianSGD(optimizer_v2.OptimizerV2):
class RiemannianSGD(OptimizerV2):
"""Optimizer that implements the Riemannian SGD algorithm."""

_HAS_AGGREGATE_GRAD = True
Expand Down

0 comments on commit 4814b0c

Please sign in to comment.