diff --git a/test/parallel/test_tensorflow2_keras.py b/test/parallel/test_tensorflow2_keras.py index d15fc3ad8d..7647b6916d 100644 --- a/test/parallel/test_tensorflow2_keras.py +++ b/test/parallel/test_tensorflow2_keras.py @@ -30,7 +30,10 @@ from horovod.common.util import is_version_greater_equal_than if is_version_greater_equal_than(tf.__version__, "2.6.0"): - from keras.optimizer_v2 import optimizer_v2 + if LooseVersion(keras.__version__) < LooseVersion("2.9.0"): + from keras.optimizer_v2 import optimizer_v2 + else: + from keras.optimizers.optimizer_v2 import optimizer_v2 else: from tensorflow.python.keras.optimizer_v2 import optimizer_v2