Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dbczumar committed Sep 3, 2021
1 parent 50cf06c commit c334b9b
Showing 1 changed file with 45 additions and 17 deletions.
62 changes: 45 additions & 17 deletions mlflow/tracking/fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1469,33 +1469,61 @@ def setup_autologging(module):
):
register_post_import_hook(setup_autologging, module, overwrite=True)

FULLY_IMPORTED_KERAS = False

def conditionally_set_up_keras_autologging(keras_module):
# NB: Keras unconditionally depends on TensorFlow beginning with Version 2.6.0, and
# many classes defined in the `keras` module are aliases of classes in the `tf.keras`
# module. Accordingly, TensorFlow autologging serves as a replacement for Keras autologging
# in Keras >= 2.6.0
nonlocal FULLY_IMPORTED_KERAS
FULLY_IMPORTED_KERAS = True

if Version(keras_module.__version__) >= Version("2.6.0"):
return
# NB: Keras unconditionally depends on TensorFlow beginning with Version 2.6.0, and
# many classes defined in the `keras` module are aliases of classes in the `tf.keras`
# module. Accordingly, TensorFlow autologging serves as a replacement for Keras
# autologging in Keras >= 2.6.0
try:
import tensorflow
setup_autologging(tensorflow)
except Exception as e:
_logger.debug(
"Failed to set up TensorFlow autologging for tf.keras models upon"
" Keras library import: %s", str(e)
)

setup_autologging(keras_module)
else:
setup_autologging(keras_module)

register_post_import_hook(conditionally_set_up_keras_autologging, "keras", overwrite=True)

def set_up_tensorflow_autologging(tensorflow_module):
# In version 2.6.0 and above of the Keras library, importing Keras unconditionally imports
# TensorFlow. When TensorFlow is imported during the Keras import procedure, the
# `tensorflow.keras` module, which is used during the setup procedure for TensorFlow
# autologging, has not yet been imported. In this circumstance, attempting to import
# `tensorflow.keras` during the TensorFlow import hook will provide the wrong module;
# the deprecated / legacy `tf.python.keras` will be provided, rather than the expected
# `keras.api._v2.keras` module. To correct this issue before enabling TensorFlow
# autologging, we manually assign `tensorflow.keras` to the correct `keras.api._v2.keras`
# module
import sys
import importlib

sys.modules['tensorflow.keras'] = importlib.import_module('keras.api._v2.keras')
setup_autologging(tensorflow_module)
nonlocal FULLY_IMPORTED_KERAS
if "keras" in sys.modules and not FULLY_IMPORTED_KERAS:
# In Keras >= 2.6.0, importing Keras imports the TensorFlow library, which can
# trigger this autologging import hook for TensorFlow before the entire Keras import
# procedure is completed. Attempting to set up autologging before the Keras import
# procedure has completed will result in a failure due to the unavailability of
# certain modules. In this case, we terminate the TensorFlow autologging import hook
# and rely on the Keras autologging import hook to successfully set up TensorFlow
# autologging for tf.keras models once the Keras import procedure has completed
return

if Version(tensorflow_module.__version__) >= Version("2.6.0"):
# In certain Keras versions >= 2.6.0, setting up a post-import hook for the `keras`
# module causes the `tensorflow.keras` module to incorrectly refer to the deprecated
# / legacy `tf.python.keras` module, rather than the expected `keras.api._v2.keras`
# module. To correct this issue, we manually assign `tensorflow.keras` to the correct
# `keras.api._v2.keras` module
try:
sys.modules['tensorflow.keras'] = importlib.import_module('keras.api._v2.keras')
except Exception as e:
_logger.debug(
"Failed to assign correct module for `tensorflow.keras`: %s", str(e)
)

if autologging_is_disabled("tensorflow"):
setup_autologging(tensorflow_module)

register_post_import_hook(set_up_tensorflow_autologging, "tensorflow", overwrite=True)

Expand Down

0 comments on commit c334b9b

Please sign in to comment.