-
Notifications
You must be signed in to change notification settings - Fork 4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix autologging compatibility with Keras >= 2.6.0 #4766
Changes from 15 commits
9c02a80
bbcc82b
7ed5725
f51cb0b
0983d71
fd70978
17f53bb
4a50aab
831ebc2
2ce92c6
edf8760
8f5bce4
31a5c02
6b646f0
d5b45b0
5d06db4
4ec1bf7
b0ce97c
c86883f
f195d2f
b65a2f2
d9cf2ca
6f2aac5
076f650
0c88bcf
29d2e99
a775965
9e32e6a
6219c3a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
import time | ||
import logging | ||
import inspect | ||
from packaging.version import Version | ||
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING | ||
|
||
from mlflow.entities import Experiment, Run, RunInfo, RunStatus, Param, RunTag, Metric, ViewType | ||
|
@@ -1463,9 +1464,62 @@ def setup_autologging(module): | |
# for each autolog library (except pyspark), register a post-import hook. | ||
# this way, we do not send any errors to the user until we know they are using the library. | ||
# the post-import hook also retroactively activates for previously-imported libraries. | ||
for module in list(set(LIBRARY_TO_AUTOLOG_FN.keys()) - set(["pyspark", "pyspark.ml"])): | ||
for module in list( | ||
set(LIBRARY_TO_AUTOLOG_FN.keys()) - set(["tensorflow", "keras", "pyspark", "pyspark.ml"]) | ||
): | ||
register_post_import_hook(setup_autologging, module, overwrite=True) | ||
|
||
FULLY_IMPORTED_KERAS = False | ||
|
||
def conditionally_set_up_keras_autologging(keras_module): | ||
nonlocal FULLY_IMPORTED_KERAS | ||
FULLY_IMPORTED_KERAS = True | ||
|
||
if Version(keras_module.__version__) >= Version("2.6.0"): | ||
# NB: Keras unconditionally depends on TensorFlow beginning with Version 2.6.0, and | ||
# many classes defined in the `keras` module are aliases of classes in the `tf.keras` | ||
# module. Accordingly, TensorFlow autologging serves as a replacement for Keras | ||
# autologging in Keras >= 2.6.0 | ||
try: | ||
import tensorflow | ||
|
||
setup_autologging(tensorflow) | ||
except Exception as e: | ||
_logger.debug( | ||
"Failed to set up TensorFlow autologging for tf.keras models upon" | ||
" Keras library import: %s", | ||
str(e), | ||
) | ||
|
||
else: | ||
setup_autologging(keras_module) | ||
|
||
register_post_import_hook(conditionally_set_up_keras_autologging, "keras", overwrite=True) | ||
|
||
def set_up_tensorflow_autologging(tensorflow_module): | ||
import sys | ||
|
||
nonlocal FULLY_IMPORTED_KERAS | ||
if "keras" in sys.modules and not FULLY_IMPORTED_KERAS: | ||
# In Keras >= 2.6.0, importing Keras imports the TensorFlow library, which can | ||
# trigger this autologging import hook for TensorFlow before the entire Keras import | ||
# procedure is completed. Attempting to set up autologging before the Keras import | ||
# procedure has completed will result in a failure due to the unavailability of | ||
# certain modules. In this case, we terminate the TensorFlow autologging import hook | ||
# and rely on the Keras autologging import hook to successfully set up TensorFlow | ||
# autologging for tf.keras models once the Keras import procedure has completed | ||
return | ||
|
||
# By design, in Keras >= 2.6.0, Keras needs to enable tensorflow autologging so that | ||
# tf.keras models always use tensorflow autologging, rather than vanilla keras autologging. | ||
# As a result, Keras autologging must call `mlflow.tensorflow.autolog()` in Keras >= 2.6.0. | ||
# Accordingly, we insert this check to ensure that importing tensorflow, which may import | ||
# keras, does not enable tensorflow autologging twice. | ||
if autologging_is_disabled("tensorflow"): | ||
setup_autologging(tensorflow_module) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By design, importing keras enables tensorflow autologging. We insert this check to ensure that importing tensorflow, which imports keras, does not enable tensorflow autologging twice. Keras needs to enable tensorflow autologging so that |
||
|
||
register_post_import_hook(set_up_tensorflow_autologging, "tensorflow", overwrite=True) | ||
|
||
# for pyspark, we activate autologging immediately, without waiting for a module import. | ||
# this is because on Databricks a SparkSession already exists and the user can directly | ||
# interact with it, and this activity should be logged. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -247,7 +247,7 @@ def __init__(self): | |
|
||
@synchronized(_post_import_hooks_lock) | ||
@synchronized(_import_error_hooks_lock) | ||
def find_module(self, fullname, path=None): | ||
def find_spec(self, fullname, path, target=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @dbczumar Confirmed from mlflow.utils.import_hooks import register_post_import_hook
register_post_import_hook(lambda x: x, "keras")
import tensorflow.keras
print(tensorflow.keras) output:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Woohoo! Awesome RCA, Haru! Really spectacular! |
||
# If the module being imported is not one we have registered | ||
# import hooks for, we can return immediately. We will | ||
# take no further part in the importing of this module. | ||
|
@@ -270,24 +270,13 @@ def find_module(self, fullname, path=None): | |
# Now call back into the import system again. | ||
|
||
try: | ||
# For Python 3 we need to use find_spec().loader | ||
# from the importlib.util module. It doesn't actually | ||
# import the target module and only finds the | ||
# loader. If a loader is found, we need to return | ||
# our own loader which will then in turn call the | ||
# real loader to import the module and invoke the | ||
# post import hooks. | ||
try: | ||
import importlib.util | ||
|
||
loader = importlib.util.find_spec(fullname).loader | ||
# If an ImportError (or AttributeError) is encountered while finding the module, | ||
# notify the hooks for import errors | ||
except (ImportError, AttributeError): | ||
notify_module_import_error(fullname) | ||
loader = importlib.find_loader(fullname, path) # pylint: disable=deprecated-method | ||
if loader: | ||
return _ImportHookChainedLoader(loader) | ||
import importlib.util | ||
|
||
spec = importlib.util.find_spec(fullname) | ||
# Replace the module spec's loader with a wrapped version that executes import | ||
# hooks when the module is loaded | ||
spec.loader = _ImportHookChainedLoader(spec.loader) | ||
return spec | ||
finally: | ||
del self.in_progress[fullname] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keras needs to enable tensorflow autologging so that
import keras
enables tensorflow autologging andimport tensorflow
only enables tensorflow autologging and not keras autologging.