Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix autologging compatibility with Keras >= 2.6.0 #4766

Merged
merged 29 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,16 @@ def autolog(

_raise_deprecation_warning()

if Version(keras.__version__) >= Version("2.6.0"):
warnings.warn(
(
"Autologging support for keras >= 2.6.0 has been deprecated and will be removed in "
"a future MLflow release. Use `mlflow.tensorflow.autolog()` instead."
),
FutureWarning,
stacklevel=2,
)

def getKerasCallback(metrics_logger):
class __MLflowKerasCallback(keras.callbacks.Callback, metaclass=ExceptionSafeClass):
"""
Expand Down
58 changes: 57 additions & 1 deletion mlflow/tracking/fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
import logging
import inspect
from packaging.version import Version
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING

from mlflow.entities import Experiment, Run, RunInfo, RunStatus, Param, RunTag, Metric, ViewType
Expand Down Expand Up @@ -1463,9 +1464,64 @@ def setup_autologging(module):
# for each autolog library (except pyspark), register a post-import hook.
# this way, we do not send any errors to the user until we know they are using the library.
# the post-import hook also retroactively activates for previously-imported libraries.
for module in list(set(LIBRARY_TO_AUTOLOG_FN.keys()) - set(["pyspark", "pyspark.ml"])):
for module in list(
set(LIBRARY_TO_AUTOLOG_FN.keys()) - set(["tensorflow", "keras", "pyspark", "pyspark.ml"])
):
register_post_import_hook(setup_autologging, module, overwrite=True)

FULLY_IMPORTED_KERAS = False
TF_AUTOLOG_SETUP_CALLED = False

def conditionally_set_up_keras_autologging(keras_module):
nonlocal FULLY_IMPORTED_KERAS, TF_AUTOLOG_SETUP_CALLED
FULLY_IMPORTED_KERAS = True

if Version(keras_module.__version__) >= Version("2.6.0"):
# NB: Keras unconditionally depends on TensorFlow beginning with Version 2.6.0, and
# many classes defined in the `keras` module are aliases of classes in the `tf.keras`
# module. Accordingly, TensorFlow autologging serves as a replacement for Keras
# autologging in Keras >= 2.6.0
try:
import tensorflow

setup_autologging(tensorflow)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keras needs to enable tensorflow autologging so that import keras enables tensorflow autologging and import tensorflow only enables tensorflow autologging and not keras autologging.

TF_AUTOLOG_SETUP_CALLED = True
except Exception as e:
_logger.debug(
"Failed to set up TensorFlow autologging for tf.keras models upon"
" Keras library import: %s",
str(e),
)
raise
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I add raise here, otherwise it will break the test_universal_autolog_throws_if_specific_autolog_throws_in_test_mode test.

else:
setup_autologging(keras_module)

register_post_import_hook(conditionally_set_up_keras_autologging, "keras", overwrite=True)

def set_up_tensorflow_autologging(tensorflow_module):
import sys

nonlocal FULLY_IMPORTED_KERAS, TF_AUTOLOG_SETUP_CALLED
if "keras" in sys.modules and not FULLY_IMPORTED_KERAS:
# In Keras >= 2.6.0, importing Keras imports the TensorFlow library, which can
# trigger this autologging import hook for TensorFlow before the entire Keras import
# procedure is completed. Attempting to set up autologging before the Keras import
# procedure has completed will result in a failure due to the unavailability of
# certain modules. In this case, we terminate the TensorFlow autologging import hook
# and rely on the Keras autologging import hook to successfully set up TensorFlow
# autologging for tf.keras models once the Keras import procedure has completed
return

# By design, in Keras >= 2.6.0, Keras needs to enable tensorflow autologging so that
# tf.keras models always use tensorflow autologging, rather than vanilla keras autologging.
# As a result, Keras autologging must call `mlflow.tensorflow.autolog()` in Keras >= 2.6.0.
# Accordingly, we insert this check to ensure that importing tensorflow, which may import
# keras, does not enable tensorflow autologging twice.
if not TF_AUTOLOG_SETUP_CALLED:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't use if autologging_is_disabled(...) here, otherwise it will break test test_autolog_success_message_obeys_disabled

setup_autologging(tensorflow_module)

register_post_import_hook(set_up_tensorflow_autologging, "tensorflow", overwrite=True)

# for pyspark, we activate autologging immediately, without waiting for a module import.
# this is because on Databricks a SparkSession already exists and the user can directly
# interact with it, and this activity should be logged.
Expand Down
37 changes: 37 additions & 0 deletions mlflow/utils/import_hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,43 @@ def find_module(self, fullname, path=None):
finally:
del self.in_progress[fullname]

@synchronized(_post_import_hooks_lock)
@synchronized(_import_error_hooks_lock)
def find_spec(self, fullname, path, target=None): # pylint: disable=unused-argument
# If the module being imported is not one we have registered
# import hooks for, we can return immediately. We will
# take no further part in the importing of this module.

if fullname not in _post_import_hooks and fullname not in _import_error_hooks:
return None

# When we are interested in a specific module, we will call back
# into the import system a second time to defer to the import
# finder that is supposed to handle the importing of the module.
# We set an in progress flag for the target module so that on
# the second time through we don't trigger another call back
# into the import system and cause a infinite loop.

if fullname in self.in_progress:
return None

self.in_progress[fullname] = True

# Now call back into the import system again.

try:
import importlib.util

spec = importlib.util.find_spec(fullname)
# Replace the module spec's loader with a wrapped version that executes import
# hooks when the module is loaded
spec.loader = _ImportHookChainedLoader(spec.loader)
return spec
except (ImportError, AttributeError):
notify_module_import_error(fullname)
finally:
del self.in_progress[fullname]


# Decorator for marking that a function should be called as a post
# import hook when the target module is imported.
Expand Down
118 changes: 117 additions & 1 deletion tests/tensorflow_autolog/test_tensorflow2_autolog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections
import pytest
import sys
from packaging.version import Version

import numpy as np
Expand All @@ -12,7 +13,7 @@
import mlflow
import mlflow.tensorflow
import mlflow.keras
from mlflow.utils.autologging_utils import BatchMetricsLogger
from mlflow.utils.autologging_utils import BatchMetricsLogger, autologging_is_disabled
from unittest.mock import patch

import os
Expand Down Expand Up @@ -53,6 +54,29 @@ def manual_run(request):
mlflow.end_run()


@pytest.fixture
def clear_tf_keras_imports():
"""
Simulates a state where `tensorflow` and `keras` are not imported by removing these
libraries from the `sys.modules` dictionary. This is useful for testing the interaction
between TensorFlow / Keras and the fluent `mlflow.autolog()` API because it will cause import
hooks to be re-triggered upon re-import after `mlflow.autolog()` is enabled.
"""
sys.modules.pop("tensorflow", None)
sys.modules.pop("keras", None)


@pytest.fixture(autouse=True)
def clear_fluent_autologging_import_hooks():
"""
Clears import hooks for MLflow fluent autologging (`mlflow.autolog()`) between tests
to ensure that interactions between fluent autologging and TensorFlow / tf.keras can
be tested successfully
"""
mlflow.utils.import_hooks._post_import_hooks.pop("tensorflow", None)
mlflow.utils.import_hooks._post_import_hooks.pop("keras", None)


def create_tf_keras_model():
model = tf.keras.Sequential()

Expand Down Expand Up @@ -808,3 +832,95 @@ def generator():
assert params["steps_per_epoch"] == "1"
assert "accuracy" in metrics
assert "loss" in metrics


@pytest.mark.large
@pytest.mark.usefixtures("clear_tf_keras_imports")
def test_fluent_autolog_with_tf_keras_logs_expected_content(
random_train_data, random_one_hot_labels
):
"""
Guards against previously-exhibited issues where using the fluent `mlflow.autolog()` API with
`tf.keras` Models did not work due to conflicting patches set by both the
`mlflow.tensorflow.autolog()` and the `mlflow.keras.autolog()` APIs.
"""
mlflow.autolog()

model = create_tf_keras_model()

with mlflow.start_run() as run:
model.fit(random_train_data, random_one_hot_labels, epochs=10)

client = mlflow.tracking.MlflowClient()
run_data = client.get_run(run.info.run_id).data
assert "accuracy" in run_data.metrics
assert "epochs" in run_data.params

artifacts = client.list_artifacts(run.info.run_id)
artifacts = map(lambda x: x.path, artifacts)
assert "model" in artifacts


@pytest.mark.large
@pytest.mark.skipif(
Version(tf.__version__) < Version("2.6.0"),
reason=("TensorFlow only has a hard dependency on Keras in version >= 2.6.0"),
)
@pytest.mark.usefixtures("clear_tf_keras_imports")
def test_fluent_autolog_with_tf_keras_preserves_v2_model_reference():
"""
Verifies that, in TensorFlow >= 2.6.0, `tensorflow.keras.Model` refers to the correct class in
the correct module after `mlflow.autolog()` is called, guarding against previously identified
compatibility issues between recent versions of TensorFlow and MLflow's internal utility for
setting up autologging import hooks.
"""
mlflow.autolog()

import tensorflow.keras
from keras.api._v2.keras import Model as ModelV2

assert tensorflow.keras.Model is ModelV2


@pytest.mark.usefixtures("clear_tf_keras_imports")
def test_import_tensorflow_with_fluent_autolog_enables_tf_autologging():
mlflow.autolog()

import tensorflow # pylint: disable=unused-variable,unused-import,reimported

assert not autologging_is_disabled(mlflow.tensorflow.FLAVOR_NAME)

# NB: For backwards compatibility, fluent autologging enables TensorFlow and
# Keras autologging upon tensorflow import in TensorFlow 2.5.1
if Version(tf.__version__) != Version("2.5.1"):
assert autologging_is_disabled(mlflow.keras.FLAVOR_NAME)


@pytest.mark.large
@pytest.mark.usefixtures("clear_tf_keras_imports")
def test_import_tf_keras_with_fluent_autolog_enables_tf_autologging():
mlflow.autolog()

import tensorflow.keras # pylint: disable=unused-variable,unused-import

assert not autologging_is_disabled(mlflow.tensorflow.FLAVOR_NAME)

# NB: For backwards compatibility, fluent autologging enables TensorFlow and
# Keras autologging upon tf.keras import in TensorFlow 2.5.1
if Version(tf.__version__) != Version("2.5.1"):
assert autologging_is_disabled(mlflow.keras.FLAVOR_NAME)


@pytest.mark.large
@pytest.mark.skipif(
Version(tf.__version__) < Version("2.6.0"),
reason=("TensorFlow autologging is not used for vanilla Keras models in Keras < 2.6.0"),
)
@pytest.mark.usefixtures("clear_tf_keras_imports")
def test_import_keras_with_fluent_autolog_enables_tensorflow_autologging():
mlflow.autolog()

import keras # pylint: disable=unused-variable,unused-import

assert not autologging_is_disabled(mlflow.tensorflow.FLAVOR_NAME)
assert autologging_is_disabled(mlflow.keras.FLAVOR_NAME)
42 changes: 35 additions & 7 deletions tests/tracking/fluent/test_fluent_autolog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import namedtuple
from io import StringIO
from unittest import mock
from packaging.version import Version

import mlflow
from mlflow.utils.autologging_utils import (
Expand All @@ -29,7 +30,9 @@

library_to_mlflow_module_without_spark_datasource = {
tensorflow: mlflow.tensorflow,
keras: mlflow.keras,
# NB: In Keras >= 2.6.0, fluent autologging enables TensorFlow logging because Keras APIs
# are aliases for tf.keras APIs in these versions of Keras
keras: mlflow.keras if Version(keras.__version__) < Version("2.6.0") else mlflow.tensorflow,
fastai: mlflow.fastai,
sklearn: mlflow.sklearn,
xgboost: mlflow.xgboost,
Expand Down Expand Up @@ -102,7 +105,17 @@ def test_universal_autolog_does_not_throw_if_specific_autolog_throws_in_standard
mlflow.autolog()
if library != pyspark and library != pyspark.ml:
autolog_mock.assert_not_called()
mlflow.utils.import_hooks.notify_module_loaded(library)

if mlflow_module == mlflow.tensorflow and Version(tensorflow.__version__) >= Version(
"2.6.0"
):
# NB: In TensorFlow >= 2.6.0, TensorFlow unconditionally imports Keras. Fluent
# autologging enablement logic relies on this import behavior.
mlflow.utils.import_hooks.notify_module_loaded(keras)
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
else:
mlflow.utils.import_hooks.notify_module_loaded(library)

autolog_mock.assert_called_once()


Expand All @@ -122,7 +135,15 @@ def test_universal_autolog_throws_if_specific_autolog_throws_in_test_mode(librar
else:
mlflow.autolog()
with pytest.raises(Exception, match="asdf"):
mlflow.utils.import_hooks.notify_module_loaded(library)
if mlflow_module == mlflow.tensorflow and Version(
tensorflow.__version__
) >= Version("2.6.0"):
# NB: In TensorFlow >= 2.6.0, TensorFlow unconditionally imports Keras. Fluent
# autologging enablement logic relies on this import behavior.
mlflow.utils.import_hooks.notify_module_loaded(keras)
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
else:
mlflow.utils.import_hooks.notify_module_loaded(library)

autolog_mock.assert_called_once()

Expand All @@ -144,7 +165,14 @@ def test_universal_autolog_calls_specific_autologs_correctly(library, mlflow_mod
args_to_test.update({"log_input_examples": True, "log_model_signatures": True})

mlflow.autolog(**args_to_test)
mlflow.utils.import_hooks.notify_module_loaded(library)

if mlflow_module == mlflow.tensorflow and Version(tensorflow.__version__) >= Version("2.6.0"):
# NB: In TensorFlow >= 2.6.0, TensorFlow unconditionally imports Keras. Fluent
# autologging enablement logic relies on this import behavior.
mlflow.utils.import_hooks.notify_module_loaded(keras)
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
else:
mlflow.utils.import_hooks.notify_module_loaded(library)

for arg_key, arg_value in args_to_test.items():
assert (
Expand Down Expand Up @@ -257,17 +285,17 @@ def test_autolog_obeys_disabled():
def test_autolog_success_message_obeys_disabled():
with mock.patch("mlflow.tracking.fluent._logger.info") as autolog_logger_mock:
mlflow.autolog(disable=True)
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
mlflow.utils.import_hooks.notify_module_loaded(sklearn)
autolog_logger_mock.assert_not_called()

mlflow.autolog()
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
mlflow.utils.import_hooks.notify_module_loaded(sklearn)
autolog_logger_mock.assert_called()

autolog_logger_mock.reset_mock()

mlflow.autolog(disable=False)
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
mlflow.utils.import_hooks.notify_module_loaded(sklearn)
autolog_logger_mock.assert_called()


Expand Down