Skip to content

Commit

Permalink
Fix Keras Import to load upstream Keras
Browse files Browse the repository at this point in the history
  • Loading branch information
DEKHTIARJonathan authored and chongxiaoc committed Feb 7, 2022
1 parent 046c071 commit a7849c7
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 8 deletions.
21 changes: 21 additions & 0 deletions horovod/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,24 @@ def is_iterable(x):
except TypeError:
return False
return True


@_cache
def is_version_greater_equal_than(ver, target):
from distutils.version import LooseVersion, StrictVersion
if any([not isinstance(_str, str) for _str in (ver, target)]):
raise ValueError("This function only accepts string arguments. \n"
"Received:\n"
"\t- ver (type {type_ver}: {val_ver})"
"\t- target (type {type_target}: {val_target})".format(
type_ver=(type(ver)),
val_ver=ver,
type_target=(type(target)),
val_target=target,
))

if len(target.split(".")) != 3:
raise ValueError("We only accepts target version values in the form "
"of: major.minor.patch. Received: {}".format(target))

return LooseVersion(ver) >= LooseVersion(target)
13 changes: 11 additions & 2 deletions horovod/spark/keras/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,17 @@

import json

from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers
import tensorflow as tf

from horovod.common.util import is_version_greater_equal_than

if is_version_greater_equal_than(tf.__version__, "2.5.0"):
from keras import backend as K
from keras import optimizers
else:
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import serialization

Expand Down
9 changes: 7 additions & 2 deletions horovod/tensorflow/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
import tensorflow as tf

from tensorflow import keras
from tensorflow.python.keras import backend as K

from horovod.common.util import is_version_greater_equal_than

if is_version_greater_equal_than(tf.__version__, "2.5.0"):
from keras import backend as K
else:
from tensorflow.python.keras import backend as K

from horovod.tensorflow import init
from horovod.tensorflow import shutdown
Expand Down Expand Up @@ -247,4 +253,3 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None, compressio
def wrap_optimizer(cls):
return lambda **kwargs: DistributedOptimizer(cls(**kwargs), compression=compression)
return _impl.load_model(keras, wrap_optimizer, _OPTIMIZER_MODULES, filepath, custom_optimizers, custom_objects)

10 changes: 9 additions & 1 deletion horovod/tensorflow/keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,16 @@
# limitations under the License.
# ==============================================================================


import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import backend as K

from horovod.common.util import is_version_greater_equal_than

if is_version_greater_equal_than(tf.__version__, "2.5.0"):
from keras import backend as K
else:
from tensorflow.python.keras import backend as K

from horovod._keras import callbacks as _impl

Expand Down
8 changes: 7 additions & 1 deletion test/parallel/test_tensorflow2_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
import pytest

from tensorflow import keras
from tensorflow.python.keras.optimizer_v2 import optimizer_v2

from horovod.common.util import is_version_greater_equal_than

if is_version_greater_equal_than(tf.__version__, "2.5.0"):
from keras.optimizer_v2 import optimizer_v2
else:
from tensorflow.python.keras.optimizer_v2 import optimizer_v2

import horovod.tensorflow.keras as hvd

Expand Down
11 changes: 9 additions & 2 deletions test/parallel/test_tensorflow_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,15 @@

from distutils.version import LooseVersion
from tensorflow import keras
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v2 import optimizer_v2

from horovod.common.util import is_version_greater_equal_than

if is_version_greater_equal_than(tf.__version__, "2.5.0"):
from keras import backend as K
from keras.optimizer_v2 import optimizer_v2
else:
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v2 import optimizer_v2

import horovod.tensorflow.keras as hvd

Expand Down

0 comments on commit a7849c7

Please sign in to comment.