Skip to content

Commit

Permalink
Fix autologging compatibility with Keras >= 2.6.0 (#4766)
Browse files Browse the repository at this point in the history
* Increase HTTP timeout to 90s. Disabled cloud storage HTTP timeout. (#4764)

* Increase HTTP timeout to 120s. Disabled cloud storage HTTP timeout.

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Include keras conditionally

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Fixes

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Fix root cause

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* docstring

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Some test cases

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Some test cases

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Tests

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Format

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Test fixes

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Test fix 2

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Remove keras change

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Use is

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Remove unused modules

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Use fixtures

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Docstring

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Fix fixtures

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Lint fixes

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Format

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Try preserve find module

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* fix2

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* Simplify fluent test cases

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Format

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Tweaks, add a warning

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Test excludee

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Reverts

Signed-off-by: dbczumar <corey.zumar@databricks.com>

Co-authored-by: jinzhang21 <78067366+jinzhang21@users.noreply.github.com>
Co-authored-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
3 people committed Sep 3, 2021
1 parent 287b16b commit f47e3ec
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 9 deletions.
10 changes: 10 additions & 0 deletions mlflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,16 @@ def autolog(

_raise_deprecation_warning()

if Version(keras.__version__) >= Version("2.6.0"):
warnings.warn(
(
"Autologging support for keras >= 2.6.0 has been deprecated and will be removed in "
"a future MLflow release. Use `mlflow.tensorflow.autolog()` instead."
),
FutureWarning,
stacklevel=2,
)

def getKerasCallback(metrics_logger):
class __MLflowKerasCallback(keras.callbacks.Callback, metaclass=ExceptionSafeClass):
"""
Expand Down
58 changes: 57 additions & 1 deletion mlflow/tracking/fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
import logging
import inspect
from packaging.version import Version
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING

from mlflow.entities import Experiment, Run, RunInfo, RunStatus, Param, RunTag, Metric, ViewType
Expand Down Expand Up @@ -1463,9 +1464,64 @@ def setup_autologging(module):
# for each autolog library (except pyspark), register a post-import hook.
# this way, we do not send any errors to the user until we know they are using the library.
# the post-import hook also retroactively activates for previously-imported libraries.
for module in list(set(LIBRARY_TO_AUTOLOG_FN.keys()) - set(["pyspark", "pyspark.ml"])):
for module in list(
set(LIBRARY_TO_AUTOLOG_FN.keys()) - set(["tensorflow", "keras", "pyspark", "pyspark.ml"])
):
register_post_import_hook(setup_autologging, module, overwrite=True)

FULLY_IMPORTED_KERAS = False
TF_AUTOLOG_SETUP_CALLED = False

def conditionally_set_up_keras_autologging(keras_module):
nonlocal FULLY_IMPORTED_KERAS, TF_AUTOLOG_SETUP_CALLED
FULLY_IMPORTED_KERAS = True

if Version(keras_module.__version__) >= Version("2.6.0"):
# NB: Keras unconditionally depends on TensorFlow beginning with Version 2.6.0, and
# many classes defined in the `keras` module are aliases of classes in the `tf.keras`
# module. Accordingly, TensorFlow autologging serves as a replacement for Keras
# autologging in Keras >= 2.6.0
try:
import tensorflow

setup_autologging(tensorflow)
TF_AUTOLOG_SETUP_CALLED = True
except Exception as e:
_logger.debug(
"Failed to set up TensorFlow autologging for tf.keras models upon"
" Keras library import: %s",
str(e),
)
raise
else:
setup_autologging(keras_module)

register_post_import_hook(conditionally_set_up_keras_autologging, "keras", overwrite=True)

def set_up_tensorflow_autologging(tensorflow_module):
import sys

nonlocal FULLY_IMPORTED_KERAS, TF_AUTOLOG_SETUP_CALLED
if "keras" in sys.modules and not FULLY_IMPORTED_KERAS:
# In Keras >= 2.6.0, importing Keras imports the TensorFlow library, which can
# trigger this autologging import hook for TensorFlow before the entire Keras import
# procedure is completed. Attempting to set up autologging before the Keras import
# procedure has completed will result in a failure due to the unavailability of
# certain modules. In this case, we terminate the TensorFlow autologging import hook
# and rely on the Keras autologging import hook to successfully set up TensorFlow
# autologging for tf.keras models once the Keras import procedure has completed
return

# By design, in Keras >= 2.6.0, Keras needs to enable tensorflow autologging so that
# tf.keras models always use tensorflow autologging, rather than vanilla keras autologging.
# As a result, Keras autologging must call `mlflow.tensorflow.autolog()` in Keras >= 2.6.0.
# Accordingly, we insert this check to ensure that importing tensorflow, which may import
# keras, does not enable tensorflow autologging twice.
if not TF_AUTOLOG_SETUP_CALLED:
setup_autologging(tensorflow_module)

register_post_import_hook(set_up_tensorflow_autologging, "tensorflow", overwrite=True)

# for pyspark, we activate autologging immediately, without waiting for a module import.
# this is because on Databricks a SparkSession already exists and the user can directly
# interact with it, and this activity should be logged.
Expand Down
37 changes: 37 additions & 0 deletions mlflow/utils/import_hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,43 @@ def find_module(self, fullname, path=None):
finally:
del self.in_progress[fullname]

@synchronized(_post_import_hooks_lock)
@synchronized(_import_error_hooks_lock)
def find_spec(self, fullname, path, target=None): # pylint: disable=unused-argument
# If the module being imported is not one we have registered
# import hooks for, we can return immediately. We will
# take no further part in the importing of this module.

if fullname not in _post_import_hooks and fullname not in _import_error_hooks:
return None

# When we are interested in a specific module, we will call back
# into the import system a second time to defer to the import
# finder that is supposed to handle the importing of the module.
# We set an in progress flag for the target module so that on
# the second time through we don't trigger another call back
# into the import system and cause a infinite loop.

if fullname in self.in_progress:
return None

self.in_progress[fullname] = True

# Now call back into the import system again.

try:
import importlib.util

spec = importlib.util.find_spec(fullname)
# Replace the module spec's loader with a wrapped version that executes import
# hooks when the module is loaded
spec.loader = _ImportHookChainedLoader(spec.loader)
return spec
except (ImportError, AttributeError):
notify_module_import_error(fullname)
finally:
del self.in_progress[fullname]


# Decorator for marking that a function should be called as a post
# import hook when the target module is imported.
Expand Down
118 changes: 117 additions & 1 deletion tests/tensorflow_autolog/test_tensorflow2_autolog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections
import pytest
import sys
from packaging.version import Version

import numpy as np
Expand All @@ -12,7 +13,7 @@
import mlflow
import mlflow.tensorflow
import mlflow.keras
from mlflow.utils.autologging_utils import BatchMetricsLogger
from mlflow.utils.autologging_utils import BatchMetricsLogger, autologging_is_disabled
from unittest.mock import patch

import os
Expand Down Expand Up @@ -53,6 +54,29 @@ def manual_run(request):
mlflow.end_run()


@pytest.fixture
def clear_tf_keras_imports():
"""
Simulates a state where `tensorflow` and `keras` are not imported by removing these
libraries from the `sys.modules` dictionary. This is useful for testing the interaction
between TensorFlow / Keras and the fluent `mlflow.autolog()` API because it will cause import
hooks to be re-triggered upon re-import after `mlflow.autolog()` is enabled.
"""
sys.modules.pop("tensorflow", None)
sys.modules.pop("keras", None)


@pytest.fixture(autouse=True)
def clear_fluent_autologging_import_hooks():
"""
Clears import hooks for MLflow fluent autologging (`mlflow.autolog()`) between tests
to ensure that interactions between fluent autologging and TensorFlow / tf.keras can
be tested successfully
"""
mlflow.utils.import_hooks._post_import_hooks.pop("tensorflow", None)
mlflow.utils.import_hooks._post_import_hooks.pop("keras", None)


def create_tf_keras_model():
model = tf.keras.Sequential()

Expand Down Expand Up @@ -808,3 +832,95 @@ def generator():
assert params["steps_per_epoch"] == "1"
assert "accuracy" in metrics
assert "loss" in metrics


@pytest.mark.large
@pytest.mark.usefixtures("clear_tf_keras_imports")
def test_fluent_autolog_with_tf_keras_logs_expected_content(
random_train_data, random_one_hot_labels
):
"""
Guards against previously-exhibited issues where using the fluent `mlflow.autolog()` API with
`tf.keras` Models did not work due to conflicting patches set by both the
`mlflow.tensorflow.autolog()` and the `mlflow.keras.autolog()` APIs.
"""
mlflow.autolog()

model = create_tf_keras_model()

with mlflow.start_run() as run:
model.fit(random_train_data, random_one_hot_labels, epochs=10)

client = mlflow.tracking.MlflowClient()
run_data = client.get_run(run.info.run_id).data
assert "accuracy" in run_data.metrics
assert "epochs" in run_data.params

artifacts = client.list_artifacts(run.info.run_id)
artifacts = map(lambda x: x.path, artifacts)
assert "model" in artifacts


@pytest.mark.large
@pytest.mark.skipif(
Version(tf.__version__) < Version("2.6.0"),
reason=("TensorFlow only has a hard dependency on Keras in version >= 2.6.0"),
)
@pytest.mark.usefixtures("clear_tf_keras_imports")
def test_fluent_autolog_with_tf_keras_preserves_v2_model_reference():
"""
Verifies that, in TensorFlow >= 2.6.0, `tensorflow.keras.Model` refers to the correct class in
the correct module after `mlflow.autolog()` is called, guarding against previously identified
compatibility issues between recent versions of TensorFlow and MLflow's internal utility for
setting up autologging import hooks.
"""
mlflow.autolog()

import tensorflow.keras
from keras.api._v2.keras import Model as ModelV2

assert tensorflow.keras.Model is ModelV2


@pytest.mark.usefixtures("clear_tf_keras_imports")
def test_import_tensorflow_with_fluent_autolog_enables_tf_autologging():
mlflow.autolog()

import tensorflow # pylint: disable=unused-variable,unused-import,reimported

assert not autologging_is_disabled(mlflow.tensorflow.FLAVOR_NAME)

# NB: For backwards compatibility, fluent autologging enables TensorFlow and
# Keras autologging upon tensorflow import in TensorFlow 2.5.1
if Version(tf.__version__) != Version("2.5.1"):
assert autologging_is_disabled(mlflow.keras.FLAVOR_NAME)


@pytest.mark.large
@pytest.mark.usefixtures("clear_tf_keras_imports")
def test_import_tf_keras_with_fluent_autolog_enables_tf_autologging():
mlflow.autolog()

import tensorflow.keras # pylint: disable=unused-variable,unused-import

assert not autologging_is_disabled(mlflow.tensorflow.FLAVOR_NAME)

# NB: For backwards compatibility, fluent autologging enables TensorFlow and
# Keras autologging upon tf.keras import in TensorFlow 2.5.1
if Version(tf.__version__) != Version("2.5.1"):
assert autologging_is_disabled(mlflow.keras.FLAVOR_NAME)


@pytest.mark.large
@pytest.mark.skipif(
Version(tf.__version__) < Version("2.6.0"),
reason=("TensorFlow autologging is not used for vanilla Keras models in Keras < 2.6.0"),
)
@pytest.mark.usefixtures("clear_tf_keras_imports")
def test_import_keras_with_fluent_autolog_enables_tensorflow_autologging():
mlflow.autolog()

import keras # pylint: disable=unused-variable,unused-import

assert not autologging_is_disabled(mlflow.tensorflow.FLAVOR_NAME)
assert autologging_is_disabled(mlflow.keras.FLAVOR_NAME)
42 changes: 35 additions & 7 deletions tests/tracking/fluent/test_fluent_autolog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import namedtuple
from io import StringIO
from unittest import mock
from packaging.version import Version

import mlflow
from mlflow.utils.autologging_utils import (
Expand All @@ -29,7 +30,9 @@

library_to_mlflow_module_without_spark_datasource = {
tensorflow: mlflow.tensorflow,
keras: mlflow.keras,
# NB: In Keras >= 2.6.0, fluent autologging enables TensorFlow logging because Keras APIs
# are aliases for tf.keras APIs in these versions of Keras
keras: mlflow.keras if Version(keras.__version__) < Version("2.6.0") else mlflow.tensorflow,
fastai: mlflow.fastai,
sklearn: mlflow.sklearn,
xgboost: mlflow.xgboost,
Expand Down Expand Up @@ -102,7 +105,17 @@ def test_universal_autolog_does_not_throw_if_specific_autolog_throws_in_standard
mlflow.autolog()
if library != pyspark and library != pyspark.ml:
autolog_mock.assert_not_called()
mlflow.utils.import_hooks.notify_module_loaded(library)

if mlflow_module == mlflow.tensorflow and Version(tensorflow.__version__) >= Version(
"2.6.0"
):
# NB: In TensorFlow >= 2.6.0, TensorFlow unconditionally imports Keras. Fluent
# autologging enablement logic relies on this import behavior.
mlflow.utils.import_hooks.notify_module_loaded(keras)
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
else:
mlflow.utils.import_hooks.notify_module_loaded(library)

autolog_mock.assert_called_once()


Expand All @@ -122,7 +135,15 @@ def test_universal_autolog_throws_if_specific_autolog_throws_in_test_mode(librar
else:
mlflow.autolog()
with pytest.raises(Exception, match="asdf"):
mlflow.utils.import_hooks.notify_module_loaded(library)
if mlflow_module == mlflow.tensorflow and Version(
tensorflow.__version__
) >= Version("2.6.0"):
# NB: In TensorFlow >= 2.6.0, TensorFlow unconditionally imports Keras. Fluent
# autologging enablement logic relies on this import behavior.
mlflow.utils.import_hooks.notify_module_loaded(keras)
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
else:
mlflow.utils.import_hooks.notify_module_loaded(library)

autolog_mock.assert_called_once()

Expand All @@ -144,7 +165,14 @@ def test_universal_autolog_calls_specific_autologs_correctly(library, mlflow_mod
args_to_test.update({"log_input_examples": True, "log_model_signatures": True})

mlflow.autolog(**args_to_test)
mlflow.utils.import_hooks.notify_module_loaded(library)

if mlflow_module == mlflow.tensorflow and Version(tensorflow.__version__) >= Version("2.6.0"):
# NB: In TensorFlow >= 2.6.0, TensorFlow unconditionally imports Keras. Fluent
# autologging enablement logic relies on this import behavior.
mlflow.utils.import_hooks.notify_module_loaded(keras)
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
else:
mlflow.utils.import_hooks.notify_module_loaded(library)

for arg_key, arg_value in args_to_test.items():
assert (
Expand Down Expand Up @@ -257,17 +285,17 @@ def test_autolog_obeys_disabled():
def test_autolog_success_message_obeys_disabled():
with mock.patch("mlflow.tracking.fluent._logger.info") as autolog_logger_mock:
mlflow.autolog(disable=True)
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
mlflow.utils.import_hooks.notify_module_loaded(sklearn)
autolog_logger_mock.assert_not_called()

mlflow.autolog()
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
mlflow.utils.import_hooks.notify_module_loaded(sklearn)
autolog_logger_mock.assert_called()

autolog_logger_mock.reset_mock()

mlflow.autolog(disable=False)
mlflow.utils.import_hooks.notify_module_loaded(tensorflow)
mlflow.utils.import_hooks.notify_module_loaded(sklearn)
autolog_logger_mock.assert_called()


Expand Down

0 comments on commit f47e3ec

Please sign in to comment.