
Add timeout for signature/requirement inference during Transformer model logging. #11037

Merged
merged 15 commits on Feb 8, 2024
10 changes: 10 additions & 0 deletions mlflow/environment_variables.py
@@ -225,6 +225,16 @@ def get(self):
"MLFLOW_ARTIFACT_UPLOAD_DOWNLOAD_TIMEOUT", int, None
)

#: Specifies the timeout for model inference with input example(s) when logging/saving a model.
#: MLflow runs a few inference requests against the model to infer the model signature and pip
#: requirements, and these predictions can hang for a long time, especially for large models.
#: This timeout caps the time allowed for a signature-inference prediction; when it expires,
#: the prediction is aborted and MLflow falls back to the default signature and pip requirements.
MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT = _EnvironmentVariable(
"MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT", int, 180
)
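
For illustration, a minimal sketch of overriding this limit before logging a slow model; the model choice and input below are placeholders rather than part of this change:

import os

# Allow up to 10 minutes of prediction time during signature inference.
os.environ["MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT"] = "600"

import mlflow
import transformers

pipe = transformers.pipeline("text-generation", model="gpt2")
with mlflow.start_run():
    mlflow.transformers.log_model(
        transformers_model=pipe,
        artifact_path="model",
        input_example="Hello, world!",
    )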


#: Specifies the device intended for use in the predict function - can be used
#: to override behavior where the GPU is used by default when available by
#: setting this environment variable to be ``cpu``. Currently, this
181 changes: 23 additions & 158 deletions mlflow/transformers/__init__.py
@@ -7,6 +7,7 @@
import contextlib
import copy
import functools
import importlib
import json
import logging
import os
@@ -38,11 +39,9 @@
Model,
ModelInputExample,
ModelSignature,
infer_pip_requirements,
infer_signature,
)
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.utils import _contains_params, _save_example
from mlflow.models.utils import _save_example
from mlflow.protos.databricks_pb2 import (
BAD_REQUEST,
INVALID_PARAMETER_VALUE,
@@ -55,7 +54,6 @@
_SUPPORTED_LLM_INFERENCE_TASK_TYPES_BY_PIPELINE_TASK,
postprocess_output_for_llm_inference_task,
)
from mlflow.types.schema import ColSpec, Schema, TensorSpec
from mlflow.types.utils import _validate_input_dictionary_contains_only_strings_and_lists_of_strings
from mlflow.utils.annotations import experimental
from mlflow.utils.autologging_utils import (
@@ -78,6 +76,7 @@
_process_pip_requirements,
_PythonEnv,
_validate_env_arguments,
infer_pip_requirements_with_timeout,
)
from mlflow.utils.file_utils import get_total_file_size, write_to
from mlflow.utils.model_utils import (
@@ -90,10 +89,21 @@
)
from mlflow.utils.requirements_utils import _get_pinned_requirement

IS_TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None

# The following modules depend on transformers and are only imported when it is available
if IS_TRANSFORMERS_AVAILABLE:
from mlflow.transformers.signature import (
_generate_signature_output,
format_input_example_for_special_cases,
infer_or_get_default_signature,
)

# The following import is only used for type hinting
if TYPE_CHECKING:
import torch
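
The two guards above are standard patterns for optional dependencies and type-only imports; a generic, self-contained sketch of the same idea (numpy and pandas here are arbitrary stand-ins, not part of this diff):

import importlib.util
from typing import TYPE_CHECKING

HAS_NUMPY = importlib.util.find_spec("numpy") is not None

if HAS_NUMPY:
    import numpy as np  # safe: the module is known to be installed

if TYPE_CHECKING:
    import pandas as pd  # evaluated by type checkers only, never at runtime


def mean(values: list[float]) -> float:
    # Fall back to pure Python when numpy is unavailable.
    if HAS_NUMPY:
        return float(np.mean(values))
    return sum(values) / len(values)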


FLAVOR_NAME = "transformers"

_CARD_TEXT_FILE_NAME = "model_card.md"
@@ -485,6 +495,8 @@ def save_model(
# using accelerate iff the model weights have been loaded using a device_map that is
# heterogeneous. There is a distinct possibility for a partial write to occur, causing an
# invalid state of the model's weights in this scenario. Hence, we raise.
# We might be able to remove this check once this PR is merged to transformers:
# https://github.com/huggingface/transformers/issues/20072
if _is_model_distributed_in_memory(built_pipeline.model):
raise MlflowException(
"The model that is attempting to be saved has been loaded into memory "
@@ -499,7 +511,7 @@
if signature is not None:
mlflow_model.signature = signature
if input_example is not None:
input_example = _format_input_example_for_special_cases(input_example, built_pipeline)
input_example = format_input_example_for_special_cases(input_example, built_pipeline)
_save_example(mlflow_model, input_example, str(path), example_no_conversion)
if metadata is not None:
mlflow_model.metadata = metadata
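
For reference, callers can avoid the inference path entirely by passing an explicit signature; a minimal sketch (model choice and metadata values are illustrative):

import mlflow
import transformers
from mlflow.models import infer_signature

pipe = transformers.pipeline("text-generation", model="gpt2")

# An explicit signature is used verbatim, so no timed inference is needed.
signature = infer_signature("Hello!", "Hello! How are you today?")
mlflow.transformers.save_model(
    transformers_model=pipe,
    path="saved_model",
    signature=signature,
    input_example="Hello!",
    metadata={"source": "example"},
)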
@@ -573,10 +585,8 @@ def save_model(
# Currently supported types are NLP-based language tasks which have a pipeline definition
# consisting exclusively of a Model and a Tokenizer.
if _should_add_pyfunc_to_model(built_pipeline):
# For pyfunc supported models, if a signature is not supplied, infer the signature
# from the input_example if provided, otherwise, apply a generic signature.
if mlflow_model.signature is None:
mlflow_model.signature = _get_default_pipeline_signature(
mlflow_model.signature = infer_or_get_default_signature(
pipeline=built_pipeline,
example=input_example,
model_config=model_config or inference_config,
@@ -628,7 +638,10 @@ def save_model(
if conda_env is None:
if pip_requirements is None:
default_reqs = get_default_pip_requirements(transformers_model.model)
inferred_reqs = infer_pip_requirements(str(path), FLAVOR_NAME, fallback=default_reqs)
# Infer the pip requirements with a timeout to avoid hanging indefinitely during prediction
inferred_reqs = infer_pip_requirements_with_timeout(
str(path), FLAVOR_NAME, fallback=default_reqs
)
default_reqs = sorted(set(inferred_reqs).union(default_reqs))
else:
default_reqs = None
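
The helper above comes from mlflow.utils.environment; as a hedged, generic sketch of the underlying pattern: run a callable under a wall-clock cap and fall back on timeout. This sketch is POSIX-only since it relies on SIGALRM, and it is not the helper's actual body:

import signal
from contextlib import contextmanager


@contextmanager
def time_limit(seconds: int):
    # POSIX-only: SIGALRM interrupts the main thread after `seconds`.
    def handler(signum, frame):
        raise TimeoutError(f"Operation exceeded {seconds}s")

    previous = signal.signal(signal.SIGALRM, handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous)


def infer_with_fallback(infer_func, fallback, seconds: int):
    try:
        with time_limit(seconds):
            return infer_func()
    except TimeoutError:
        # Prediction hung or was too slow; use the safe default instead.
        return fallback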
@@ -1560,149 +1573,6 @@ def _should_add_pyfunc_to_model(pipeline) -> bool:
return True


def _format_input_example_for_special_cases(input_example, pipeline):
"""
Handles special formatting for specific Pipeline types so that the saved example
reflects the input structure expected by the pyfunc input parser.
"""
import transformers

input_data = input_example[0] if isinstance(input_example, tuple) else input_example

if (
isinstance(pipeline, transformers.ZeroShotClassificationPipeline)
and isinstance(input_data, dict)
and isinstance(input_data["candidate_labels"], list)
):
input_data["candidate_labels"] = json.dumps(input_data["candidate_labels"])
return input_data if not isinstance(input_example, tuple) else (input_data, input_example[1])
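
Concretely, for a zero-shot pipeline this serializes a list of candidate labels into a JSON string so the saved example matches what the pyfunc parser expects; a small illustration:

import json

input_example = {
    "sequences": "I love this movie!",
    "candidate_labels": ["positive", "negative"],
}

# After formatting, the labels become a JSON-encoded string:
input_example["candidate_labels"] = json.dumps(input_example["candidate_labels"])
assert input_example["candidate_labels"] == '["positive", "negative"]'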


def _get_default_pipeline_signature(
pipeline, example=None, model_config=None, flavor_config=None
) -> ModelSignature:
"""
Assigns a default ModelSignature for a given Pipeline type that has pyfunc support. These
default signatures should only be generated and assigned when saving a model if the user
has not supplied a signature.
For signature inference in some Pipelines that support complex input types, an input example
is needed.
"""

import transformers

if example:
try:
params = None
if _contains_params(example):
example, params = example
example = _format_input_example_for_special_cases(example, pipeline)
prediction = generate_signature_output(
pipeline=pipeline,
data=example,
model_config=model_config,
params=params,
flavor_config=flavor_config,
)
return infer_signature(example, prediction, params)
except Exception as e:
_logger.warning(
"Attempted to generate a signature for the saved model or pipeline "
f"but encountered an error: {e}"
)
raise
else:
if isinstance(
pipeline,
(
transformers.TokenClassificationPipeline,
transformers.ConversationalPipeline,
transformers.TranslationPipeline,
transformers.FillMaskPipeline,
transformers.TextGenerationPipeline,
transformers.Text2TextGenerationPipeline,
),
):
return ModelSignature(
inputs=Schema([ColSpec("string")]), outputs=Schema([ColSpec("string")])
)
elif isinstance(
pipeline,
(
transformers.TextClassificationPipeline,
transformers.ImageClassificationPipeline,
),
):
return ModelSignature(
inputs=Schema([ColSpec("string")]),
outputs=Schema([ColSpec("string", name="label"), ColSpec("double", name="score")]),
)
elif isinstance(pipeline, transformers.ZeroShotClassificationPipeline):
return ModelSignature(
inputs=Schema(
[
ColSpec("string", name="sequences"),
ColSpec("string", name="candidate_labels"),
ColSpec("string", name="hypothesis_template"),
]
),
outputs=Schema(
[
ColSpec("string", name="sequence"),
ColSpec("string", name="labels"),
ColSpec("double", name="scores"),
]
),
)
elif isinstance(pipeline, transformers.AutomaticSpeechRecognitionPipeline):
return ModelSignature(
inputs=Schema([ColSpec("binary")]),
outputs=Schema([ColSpec("string")]),
)
elif isinstance(pipeline, transformers.AudioClassificationPipeline):
return ModelSignature(
inputs=Schema([ColSpec("binary")]),
outputs=Schema([ColSpec("double", name="score"), ColSpec("string", name="label")]),
)
elif isinstance(
pipeline,
(
transformers.TableQuestionAnsweringPipeline,
transformers.QuestionAnsweringPipeline,
),
):
column_1 = None
column_2 = None
if isinstance(pipeline, transformers.TableQuestionAnsweringPipeline):
column_1 = "query"
column_2 = "table"
elif isinstance(pipeline, transformers.QuestionAnsweringPipeline):
column_1 = "question"
column_2 = "context"
return ModelSignature(
inputs=Schema(
[
ColSpec("string", name=column_1),
ColSpec("string", name=column_2),
]
),
outputs=Schema([ColSpec("string")]),
)
elif isinstance(pipeline, transformers.FeatureExtractionPipeline):
return ModelSignature(
inputs=Schema([ColSpec("string")]),
outputs=Schema([TensorSpec(np.dtype("float64"), [-1], "double")]),
)
else:
_logger.warning(
"An unsupported Pipeline type was supplied for signature inference. "
"Either provide an `input_example` or generate a signature manually "
"via `infer_signature` if you would like to have a signature recorded "
"in the MLmodel file."
)
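
When an example is supplied, the inferred branch above reduces to running the pipeline once and calling infer_signature on the input/output pair; a minimal sketch (model choice is illustrative):

import transformers
from mlflow.models import infer_signature

pipe = transformers.pipeline("text-generation", model="gpt2")
example = "The weather today is"
prediction = pipe(example)[0]["generated_text"]
signature = infer_signature(example, prediction)
print(signature)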


class _TransformersModel(NamedTuple):
"""
Type validator class for models that are submitted as a dictionary for saving and logging.
@@ -1867,12 +1737,7 @@ def generate_signature_output(pipeline, data, model_config=None, params=None, fl
error_code=INVALID_PARAMETER_VALUE,
)

pyfunc_model = _TransformersWrapper(
pipeline=pipeline,
flavor_config=flavor_config,
model_config=model_config,
)
return pyfunc_model.predict(data, params=params)
return _generate_signature_output(pipeline, data, model_config, params)
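
The public API is unchanged by this delegation; usage remains as before, for example (model choice is illustrative):

import transformers
from mlflow.transformers import generate_signature_output

pipe = transformers.pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
output = generate_signature_output(pipe, "MLflow is great!")
print(output)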


class _TransformersWrapper: