Fix autologging compatibility with Keras >= 2.6.0 #4766
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
import time | ||
import logging | ||
import inspect | ||
from packaging.version import Version | ||
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING | ||
|
||
from mlflow.entities import Experiment, Run, RunInfo, RunStatus, Param, RunTag, Metric, ViewType | ||
|
@@ -1463,9 +1464,64 @@ def setup_autologging(module): | |
# for each autolog library (except pyspark), register a post-import hook. | ||
# this way, we do not send any errors to the user until we know they are using the library. | ||
# the post-import hook also retroactively activates for previously-imported libraries. | ||
for module in list(set(LIBRARY_TO_AUTOLOG_FN.keys()) - set(["pyspark", "pyspark.ml"])): | ||
for module in list( | ||
set(LIBRARY_TO_AUTOLOG_FN.keys()) - set(["tensorflow", "keras", "pyspark", "pyspark.ml"]) | ||
): | ||
register_post_import_hook(setup_autologging, module, overwrite=True) | ||
|
||
FULLY_IMPORTED_KERAS = False | ||
TF_AUTOLOG_SETUP_CALLED = False | ||
|
||
def conditionally_set_up_keras_autologging(keras_module): | ||
nonlocal FULLY_IMPORTED_KERAS, TF_AUTOLOG_SETUP_CALLED | ||
FULLY_IMPORTED_KERAS = True | ||
|
||
if Version(keras_module.__version__) >= Version("2.6.0"): | ||
# NB: Keras unconditionally depends on TensorFlow beginning with Version 2.6.0, and | ||
# many classes defined in the `keras` module are aliases of classes in the `tf.keras` | ||
# module. Accordingly, TensorFlow autologging serves as a replacement for Keras | ||
# autologging in Keras >= 2.6.0 | ||
try: | ||
import tensorflow | ||
|
||
setup_autologging(tensorflow) | ||
TF_AUTOLOG_SETUP_CALLED = True | ||
except Exception as e: | ||
_logger.debug( | ||
"Failed to set up TensorFlow autologging for tf.keras models upon" | ||
" Keras library import: %s", | ||
str(e), | ||
) | ||
raise | ||
I add |
||
else: | ||
setup_autologging(keras_module) | ||
|
||
register_post_import_hook(conditionally_set_up_keras_autologging, "keras", overwrite=True) | ||
|
||
def set_up_tensorflow_autologging(tensorflow_module): | ||
import sys | ||
|
||
nonlocal FULLY_IMPORTED_KERAS, TF_AUTOLOG_SETUP_CALLED | ||
if "keras" in sys.modules and not FULLY_IMPORTED_KERAS: | ||
# In Keras >= 2.6.0, importing Keras imports the TensorFlow library, which can | ||
# trigger this autologging import hook for TensorFlow before the entire Keras import | ||
# procedure is completed. Attempting to set up autologging before the Keras import | ||
# procedure has completed will result in a failure due to the unavailability of | ||
# certain modules. In this case, we terminate the TensorFlow autologging import hook | ||
# and rely on the Keras autologging import hook to successfully set up TensorFlow | ||
# autologging for tf.keras models once the Keras import procedure has completed | ||
return | ||
|
||
# By design, in Keras >= 2.6.0, Keras needs to enable tensorflow autologging so that | ||
# tf.keras models always use tensorflow autologging, rather than vanilla keras autologging. | ||
# As a result, Keras autologging must call `mlflow.tensorflow.autolog()` in Keras >= 2.6.0. | ||
# Accordingly, we insert this check to ensure that importing tensorflow, which may import | ||
# keras, does not enable tensorflow autologging twice. | ||
if not TF_AUTOLOG_SETUP_CALLED: | ||
Don't use |
||
setup_autologging(tensorflow_module) | ||
|
||
register_post_import_hook(set_up_tensorflow_autologging, "tensorflow", overwrite=True) | ||
|
||
# for pyspark, we activate autologging immediately, without waiting for a module import. | ||
# this is because on Databricks a SparkSession already exists and the user can directly | ||
# interact with it, and this activity should be logged. | ||
|
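The version gate in `conditionally_set_up_keras_autologging` compares parsed versions rather than raw strings. The sketch below illustrates why that matters, using only the standard library. `parse_release` and `uses_tensorflow_autologging` are hypothetical helpers written for this illustration; they stand in for `packaging.version.Version`, which the diff actually uses and which also handles pre-release and local version segments that this sketch ignores.

```python
def parse_release(version: str) -> tuple:
    """Hypothetical helper: turn a plain 'X.Y.Z' string into a tuple of ints.

    Assumes a purely numeric release string; packaging.version.Version
    (used in the diff) handles far more formats than this.
    """
    return tuple(int(part) for part in version.split("."))


def uses_tensorflow_autologging(keras_version: str) -> bool:
    """Mirror the diff's gate: Keras >= 2.6.0 is routed to TensorFlow autologging."""
    return parse_release(keras_version) >= parse_release("2.6.0")


if __name__ == "__main__":
    for candidate in ("2.4.3", "2.6.0", "2.10.0"):
        print(candidate, uses_tensorflow_autologging(candidate))
    # A raw string comparison orders "2.10.0" before "2.6.0",
    # because "1" sorts before "6" character by character.
    print("2.10.0" >= "2.6.0")
```

Comparing tuples of integers (or `Version` objects) orders `2.10.0` after `2.6.0`, whereas comparing the strings directly does not.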
Keras needs to enable tensorflow autologging so that `import keras` enables tensorflow autologging and `import tensorflow` only enables tensorflow autologging and not keras autologging.
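The interaction between the two import hooks and their guard flags can be sketched in isolation. This is a simplified simulation, not MLflow's implementation: `AutologState`, `make_hooks`, and the two `simulate_*` functions are hypothetical names, the `keras_in_sys_modules` argument replaces the real `"keras" in sys.modules` check, and `setup_autologging` only records which library it was called for.

```python
class AutologState:
    """Hypothetical container for the two flags used in the diff."""

    def __init__(self):
        self.fully_imported_keras = False
        self.tf_autolog_setup_called = False
        self.setup_calls = []  # records each simulated setup_autologging call


def make_hooks(state, keras_uses_tf):
    """Build simulated Keras and TensorFlow post-import hooks.

    keras_uses_tf stands in for the `Version(...) >= Version("2.6.0")` check.
    """

    def setup_autologging(library_name):
        state.setup_calls.append(library_name)

    def on_keras_imported():
        state.fully_imported_keras = True
        if keras_uses_tf:
            # Keras >= 2.6.0: TensorFlow autologging replaces Keras autologging.
            setup_autologging("tensorflow")
            state.tf_autolog_setup_called = True
        else:
            setup_autologging("keras")

    def on_tensorflow_imported(keras_in_sys_modules):
        if keras_in_sys_modules and not state.fully_imported_keras:
            # Keras is mid-import; defer to the Keras hook.
            return
        if not state.tf_autolog_setup_called:
            setup_autologging("tensorflow")

    return on_keras_imported, on_tensorflow_imported


def simulate_import_keras(keras_uses_tf):
    """Simulate `import keras`; newer Keras triggers the TF hook mid-import."""
    state = AutologState()
    on_keras, on_tf = make_hooks(state, keras_uses_tf)
    if keras_uses_tf:
        on_tf(keras_in_sys_modules=True)  # fires before Keras finishes importing
    on_keras()
    return state.setup_calls


def simulate_import_tensorflow():
    """Simulate `import tensorflow` with no Keras import in progress."""
    state = AutologState()
    _, on_tf = make_hooks(state, keras_uses_tf=True)
    on_tf(keras_in_sys_modules=False)
    return state.setup_calls


if __name__ == "__main__":
    print(simulate_import_keras(keras_uses_tf=True))
    print(simulate_import_keras(keras_uses_tf=False))
    print(simulate_import_tensorflow())
```

In this simulation each scenario sets up autologging exactly once: the early TensorFlow hook returns without doing anything while Keras is still importing, and the Keras hook then performs the TensorFlow setup itself.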