diff --git a/horovod/common/util.py b/horovod/common/util.py index 8558828562..c1cd58d378 100644 --- a/horovod/common/util.py +++ b/horovod/common/util.py @@ -265,3 +265,24 @@ def is_iterable(x): except TypeError: return False return True + + +@_cache +def is_version_greater_equal_than(ver, target): + from distutils.version import LooseVersion, StrictVersion + if any([not isinstance(_str, str) for _str in (ver, target)]): + raise ValueError("This function only accepts string arguments. \n" + "Received:\n" + "\t- ver (type {type_ver}: {val_ver})" + "\t- target (type {type_target}: {val_target})".format( + type_ver=(type(ver)), + val_ver=ver, + type_target=(type(target)), + val_target=target, + )) + + if len(target.split(".")) != 3: + raise ValueError("We only accepts target version values in the form " + "of: major.minor.patch. Received: {}".format(target)) + + return LooseVersion(ver) >= LooseVersion(target) diff --git a/horovod/spark/keras/tensorflow.py b/horovod/spark/keras/tensorflow.py index 3be6fc4d35..2f63724e5f 100644 --- a/horovod/spark/keras/tensorflow.py +++ b/horovod/spark/keras/tensorflow.py @@ -15,8 +15,17 @@ import json -from tensorflow.python.keras import backend as K -from tensorflow.python.keras import optimizers +import tensorflow as tf + +from horovod.common.util import is_version_greater_equal_than + +if is_version_greater_equal_than(tf.__version__, "2.5.0"): + from keras import backend as K + from keras import optimizers +else: + from tensorflow.python.keras import backend as K + from tensorflow.python.keras import optimizers + from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import serialization diff --git a/horovod/tensorflow/keras/__init__.py b/horovod/tensorflow/keras/__init__.py index c381b3edbd..6d18bcf746 100644 --- a/horovod/tensorflow/keras/__init__.py +++ b/horovod/tensorflow/keras/__init__.py @@ -19,7 +19,13 @@ import tensorflow as tf from tensorflow import keras -from tensorflow.python.keras import backend as K + +from horovod.common.util import is_version_greater_equal_than + +if is_version_greater_equal_than(tf.__version__, "2.5.0"): + from keras import backend as K +else: + from tensorflow.python.keras import backend as K from horovod.tensorflow import init from horovod.tensorflow import shutdown @@ -247,4 +253,3 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None, compressio def wrap_optimizer(cls): return lambda **kwargs: DistributedOptimizer(cls(**kwargs), compression=compression) return _impl.load_model(keras, wrap_optimizer, _OPTIMIZER_MODULES, filepath, custom_optimizers, custom_objects) - diff --git a/horovod/tensorflow/keras/callbacks.py b/horovod/tensorflow/keras/callbacks.py index b91f367088..5e6dbeb1c0 100644 --- a/horovod/tensorflow/keras/callbacks.py +++ b/horovod/tensorflow/keras/callbacks.py @@ -13,8 +13,16 @@ # limitations under the License. # ============================================================================== + +import tensorflow as tf from tensorflow import keras -from tensorflow.python.keras import backend as K + +from horovod.common.util import is_version_greater_equal_than + +if is_version_greater_equal_than(tf.__version__, "2.5.0"): + from keras import backend as K +else: + from tensorflow.python.keras import backend as K from horovod._keras import callbacks as _impl diff --git a/test/parallel/test_tensorflow2_keras.py b/test/parallel/test_tensorflow2_keras.py index df61e2e521..6836c43185 100644 --- a/test/parallel/test_tensorflow2_keras.py +++ b/test/parallel/test_tensorflow2_keras.py @@ -26,7 +26,13 @@ import pytest from tensorflow import keras -from tensorflow.python.keras.optimizer_v2 import optimizer_v2 + +from horovod.common.util import is_version_greater_equal_than + +if is_version_greater_equal_than(tf.__version__, "2.5.0"): + from keras.optimizer_v2 import optimizer_v2 +else: + from tensorflow.python.keras.optimizer_v2 import optimizer_v2 import horovod.tensorflow.keras as hvd diff --git a/test/parallel/test_tensorflow_keras.py b/test/parallel/test_tensorflow_keras.py index 91e5e78321..763892c902 100644 --- a/test/parallel/test_tensorflow_keras.py +++ b/test/parallel/test_tensorflow_keras.py @@ -25,8 +25,15 @@ from distutils.version import LooseVersion from tensorflow import keras -from tensorflow.python.keras import backend as K -from tensorflow.python.keras.optimizer_v2 import optimizer_v2 + +from horovod.common.util import is_version_greater_equal_than + +if is_version_greater_equal_than(tf.__version__, "2.5.0"): + from keras import backend as K + from keras.optimizer_v2 import optimizer_v2 +else: + from tensorflow.python.keras import backend as K + from tensorflow.python.keras.optimizer_v2 import optimizer_v2 import horovod.tensorflow.keras as hvd