Skip to content

Commit

Permalink
feedback
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
  • Loading branch information
B-Step62 committed Feb 7, 2024
1 parent 186d493 commit ca93374
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 22 deletions.
3 changes: 2 additions & 1 deletion mlflow/environment_variables.py
Expand Up @@ -228,7 +228,8 @@ def get(self):
#: Specifies the timeout for model inference with input example(s) when logging/saving a model.
#: MLflow runs a few inference requests against the model to infer model signature and pip
#: requirements. Sometimes the prediction hangs for a long time, especially for a large model.
#: This timeout avoids the hanging and falls back to the default signature and pip requirements.
#: This timeout limits the allowable time for performing a prediction for signature inference
#: and will abort the prediction, falling back to the default signature and pip requirements.
# Wraps the MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT environment variable; the value is
# read as an int, defaulting to 180 (seconds) when the variable is unset.
MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT = _EnvironmentVariable(
    "MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT", int, 180
)
Expand Down
2 changes: 1 addition & 1 deletion mlflow/transformers/__init__.py
Expand Up @@ -91,7 +91,7 @@

# True when the `transformers` package is installed; find_spec checks availability
# without actually importing the (heavy) library.
IS_TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None

# The following modules depends on transformers and only imported when it is available
# The following modules depend on transformers and are only imported when it is available
if IS_TRANSFORMERS_AVAILABLE:
from mlflow.transformers.signature import (
_generate_signature_output,
Expand Down
5 changes: 2 additions & 3 deletions mlflow/transformers/signature.py
Expand Up @@ -11,7 +11,7 @@
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.types.schema import ColSpec, DataType, Schema, TensorSpec
from mlflow.utils.process import _IS_UNIX
from mlflow.utils.timeout import MLflowTimeoutError, run_with_timeout
from mlflow.utils.timeout import MlflowTimeoutError, run_with_timeout

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -104,14 +104,13 @@ def infer_or_get_default_signature(
return _infer_signature_with_prediction(
pipeline, example, model_config, flavor_config, timeout
)
except MLflowTimeoutError:
except MlflowTimeoutError:
_logger.warning(
"Attempted to generate a signature for the saved model but prediction operation "
f"timed out after {timeout} seconds. Falling back to the default signature for the "
"pipeline. You can specify a signature manually or increase the timeout "
f"by setting the environment variable {MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT}"
)
pass
except Exception as e:
_logger.error(
"Attempted to generate a signature for the saved model or pipeline "
Expand Down
24 changes: 14 additions & 10 deletions mlflow/utils/_capture_transformers_modules.py
Expand Up @@ -7,6 +7,7 @@

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils._capture_modules import (
_CaptureImportedModules,
parse_args,
Expand All @@ -32,16 +33,6 @@ def _record_imported_module(self, full_module_name):
raise ImportError(f"Disabled package {full_module_name}")
return super()._record_imported_module(full_module_name)

def __enter__(self):
# Patch the environment variables to disable module_to_throw
# https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/utils/import_utils.py#L60-L62
if self.module_to_throw == "tensorflow":
os.environ["USE_TORCH"] = "TRUE"
elif self.module_to_throw == "torch":
os.environ["USE_TF"] = "TRUE"

return super().__enter__()


def main():
args = parse_args()
Expand All @@ -60,6 +51,19 @@ def main():

if module_to_throw == "":
raise MlflowException("Please specify the module to throw.")
elif module_to_throw == "tensorflow":
if not os.environ.get("USE_TORCH", None) == "TRUE":
raise MlflowException(
"The environment variable USE_TORCH has to be set to TRUE to disable Tensorflow.",
error_code=INVALID_PARAMETER_VALUE,
)
elif module_to_throw == "torch":
if not os.environ.get("USE_TF", None) == "TRUE":
raise MlflowException(
"The environment variable USE_TF has to be set to TRUE to disable Pytorch.",
error_code=INVALID_PARAMETER_VALUE,
)

cap_cm = _CaptureImportedModulesForHF(module_to_throw)
store_imported_modules(cap_cm, model_path, flavor, output_file)

Expand Down
4 changes: 2 additions & 2 deletions mlflow/utils/environment.py
Expand Up @@ -19,7 +19,7 @@
_parse_requirements,
warn_dependency_requirement_mismatches,
)
from mlflow.utils.timeout import MLflowTimeoutError, run_with_timeout
from mlflow.utils.timeout import MlflowTimeoutError, run_with_timeout
from mlflow.version import VERSION

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -412,7 +412,7 @@ def infer_pip_requirements_with_timeout(model_uri, flavor, fallback):
return infer_pip_requirements(model_uri, flavor, fallback)
except Exception as e:
if fallback is not None:
if isinstance(e, MLflowTimeoutError):
if isinstance(e, MlflowTimeoutError):
msg = (
"Attempted to infer pip requirements for the saved model or pipeline but the "
f"operation timed out in {timeout} seconds. Fall back to return {fallback}. "
Expand Down
11 changes: 9 additions & 2 deletions mlflow/utils/requirements_utils.py
Expand Up @@ -281,7 +281,14 @@ def _capture_imported_modules(model_uri, flavor):
from mlflow.utils import _capture_transformers_modules

for module_to_throw in ["tensorflow", "torch"]:
transformer_env = {"USE_TF": "TRUE"} if module_to_throw == "torch" else {"USE_TORCH": "TRUE"}
    # NB: Setting USE_TF or USE_TORCH here as Transformers only checks these env
    # variables on the first import of the library, which could happen anytime during
    # the model loading process (or even mlflow import). When these variables are not
    # set, Transformers imports some torch/tensorflow modules even if they are not
    # used by the model, resulting in false positives in the captured modules.
transformer_env = (
{"USE_TF": "TRUE"} if module_to_throw == "torch" else {"USE_TORCH": "TRUE"}
)
try:
_run_command(
[
Expand All @@ -299,7 +306,7 @@ def _capture_imported_modules(model_uri, flavor):
module_to_throw,
],
timeout_seconds=process_timeout,
env={**main_env, **transformer_env}
env={**main_env, **transformer_env},
)
with open(output_file) as f:
return f.read().splitlines()
Expand Down
12 changes: 9 additions & 3 deletions mlflow/utils/timeout.py
@@ -1,10 +1,12 @@
import signal
from contextlib import contextmanager

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import NOT_IMPLEMENTED
from mlflow.utils.process import _IS_UNIX


class MLflowTimeoutError(Exception):
class MlflowTimeoutError(Exception):
    """Raised when an operation wrapped by ``run_with_timeout`` exceeds its time limit."""


Expand All @@ -22,10 +24,14 @@ def run_with_timeout(seconds):
model.predict(data)
```
"""
assert _IS_UNIX, "Timeouts are not implemented yet for non-Unix platforms"
if not _IS_UNIX:
raise MlflowException(
"Timeouts are not implemented yet for non-Unix platforms",
error_code=NOT_IMPLEMENTED,
)

def signal_handler(signum, frame):
raise MLflowTimeoutError(f"Operation timed out after {seconds} seconds")
raise MlflowTimeoutError(f"Operation timed out after {seconds} seconds")

signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(seconds)
Expand Down
2 changes: 2 additions & 0 deletions tests/transformers/test_transformers_model_export.py
Expand Up @@ -927,6 +927,8 @@ def test_transformers_tf_model_log_without_conda_env_uses_default_env_with_expec
pip_requirements = _get_deps_from_requirement_file(model_uri)
assert "tensorflow" in pip_requirements
assert "torch" not in pip_requirements
# Accelerate installs Pytorch along with it, so it should not be present in the requirements
assert "accelerate" not in pip_requirements


def test_transformers_pt_model_log_without_conda_env_uses_default_env_with_expected_dependencies(
Expand Down

0 comments on commit ca93374

Please sign in to comment.