Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix autologging compatibility with Keras >= 2.6.0 #4766

Merged
merged 29 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 53 additions & 1 deletion mlflow/tracking/fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
import logging
import inspect
from packaging.version import Version
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING

from mlflow.entities import Experiment, Run, RunInfo, RunStatus, Param, RunTag, Metric, ViewType
Expand Down Expand Up @@ -1463,9 +1464,60 @@ def setup_autologging(module):
# for each autolog library (except pyspark), register a post-import hook.
# this way, we do not send any errors to the user until we know they are using the library.
# the post-import hook also retroactively activates for previously-imported libraries.
for module in list(set(LIBRARY_TO_AUTOLOG_FN.keys()) - set(["pyspark", "pyspark.ml"])):
for module in list(
set(LIBRARY_TO_AUTOLOG_FN.keys()) - set(["tensorflow", "keras", "pyspark", "pyspark.ml"])
):
register_post_import_hook(setup_autologging, module, overwrite=True)

FULLY_IMPORTED_KERAS = False

def conditionally_set_up_keras_autologging(keras_module):
nonlocal FULLY_IMPORTED_KERAS
FULLY_IMPORTED_KERAS = True

if Version(keras_module.__version__) >= Version("2.6.0"):
# NB: Keras unconditionally depends on TensorFlow beginning with Version 2.6.0, and
# many classes defined in the `keras` module are aliases of classes in the `tf.keras`
# module. Accordingly, TensorFlow autologging serves as a replacement for Keras
# autologging in Keras >= 2.6.0
try:
import tensorflow
setup_autologging(tensorflow)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keras autologging needs to enable TensorFlow autologging here, so that `import keras` enables TensorFlow autologging, while `import tensorflow` enables only TensorFlow autologging and not Keras autologging.

except Exception as e:
_logger.debug(
"Failed to set up TensorFlow autologging for tf.keras models upon"
" Keras library import: %s", str(e)
)

else:
setup_autologging(keras_module)

register_post_import_hook(conditionally_set_up_keras_autologging, "keras", overwrite=True)

def set_up_tensorflow_autologging(tensorflow_module):
import sys

nonlocal FULLY_IMPORTED_KERAS
if "keras" in sys.modules and not FULLY_IMPORTED_KERAS:
# In Keras >= 2.6.0, importing Keras imports the TensorFlow library, which can
# trigger this autologging import hook for TensorFlow before the entire Keras import
# procedure is completed. Attempting to set up autologging before the Keras import
# procedure has completed will result in a failure due to the unavailability of
# certain modules. In this case, we terminate the TensorFlow autologging import hook
# and rely on the Keras autologging import hook to successfully set up TensorFlow
# autologging for tf.keras models once the Keras import procedure has completed
return

# By design, in Keras >= 2.6.0, Keras needs to enable tensorflow autologging so that
# tf.keras models always use tensorflow autologging, rather than vanilla keras autologging.
# As a result, Keras autologging must call `mlflow.tensorflow.autolog()` in Keras >= 2.6.0.
# Accordingly, we insert this check to ensure that importing tensorflow, which may import
# keras, does not enable tensorflow autologging twice.
if autologging_is_disabled("tensorflow"):
setup_autologging(tensorflow_module)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By design, importing `keras` enables TensorFlow autologging. We insert this check to ensure that importing `tensorflow`, which in turn imports `keras`, does not enable TensorFlow autologging twice.

Keras autologging needs to enable TensorFlow autologging, so that `import keras` enables TensorFlow autologging, while `import tensorflow` enables only TensorFlow autologging and not Keras autologging.


register_post_import_hook(set_up_tensorflow_autologging, "tensorflow", overwrite=True)

# for pyspark, we activate autologging immediately, without waiting for a module import.
# this is because on Databricks a SparkSession already exists and the user can directly
# interact with it, and this activity should be logged.
Expand Down
27 changes: 8 additions & 19 deletions mlflow/utils/import_hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def __init__(self):

@synchronized(_post_import_hooks_lock)
@synchronized(_import_error_hooks_lock)
def find_module(self, fullname, path=None):
def find_spec(self, fullname, path, target=None):
Copy link
Member

@harupy harupy Sep 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dbczumar Confirmed that `tensorflow.keras` now resolves to `site-packages/keras/api/_v2/keras/__init__.py`. Reproduction:

from mlflow.utils.import_hooks import register_post_import_hook

register_post_import_hook(lambda x: x, "keras")

import tensorflow.keras

print(tensorflow.keras)

output:

<module 'tensorflow.keras' from '/Users/harutakakawamura/.pyenv/versions/miniconda3-4.7.12/envs/mlflow-dev-env/lib/python3.7/site-packages/keras/api/_v2/keras/__init__.py'>

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woohoo! Awesome RCA, Haru! Really spectacular!

# If the module being imported is not one we have registered
# import hooks for, we can return immediately. We will
# take no further part in the importing of this module.
Expand All @@ -270,24 +270,13 @@ def find_module(self, fullname, path=None):
# Now call back into the import system again.

try:
# For Python 3 we need to use find_spec().loader
# from the importlib.util module. It doesn't actually
# import the target module and only finds the
# loader. If a loader is found, we need to return
# our own loader which will then in turn call the
# real loader to import the module and invoke the
# post import hooks.
try:
import importlib.util

loader = importlib.util.find_spec(fullname).loader
# If an ImportError (or AttributeError) is encountered while finding the module,
# notify the hooks for import errors
except (ImportError, AttributeError):
notify_module_import_error(fullname)
loader = importlib.find_loader(fullname, path) # pylint: disable=deprecated-method
if loader:
return _ImportHookChainedLoader(loader)
import importlib.util

spec = importlib.util.find_spec(fullname)
# Replace the module spec's loader with a wrapped version that executes import
# hooks when the module is loaded
spec.loader = _ImportHookChainedLoader(spec.loader)
return spec
finally:
del self.in_progress[fullname]

Expand Down