
Add timeout for signature/requirement inference during Transformer model logging. #11037

Merged
merged 15 commits into from Feb 8, 2024

Conversation

@B-Step62 (Collaborator) commented Feb 6, 2024

🛠 DevTools 🛠

Open in GitHub Codespaces

Install mlflow from this PR

pip install git+https://github.com/mlflow/mlflow.git@refs/pull/11037/merge

Checkout with GitHub CLI

gh pr checkout 11037

What changes are proposed in this pull request?

Problem

Transformers model saving involves a few model predictions if an input example is provided.

  1. Prediction to generate model output for inferring model signature.
  2. Prediction to track imported module to infer pip requirements.

However, this can take significantly long for huge models such as LLMs, because we don't support saving models with optimized device mapping, i.e. models distributed across multiple devices. This is a limitation of the Transformers save_pretrained method we use during logging (context). As a result, users can only save models on a single CPU/GPU, where these predictions can take hours or more for huge models.

What didn't work

Initially, I tried to solve this problem by allowing models to be saved with device_mapping, moving the model to a single CPU/GPU just before calling save_pretrained. If this were possible, we could run the predictions for signature/requirement inference with the optimized device setting, which would keep the latency acceptably short.

However, it turns out that moving a model to a different device is challenging when it was originally distributed across multiple devices. For example, running model.to(torch.device("cpu")) raises RuntimeError: You can't move a model that has some modules offloaded to cpu or disk.

We could implement device handling using the native torch/tf libraries, but I found that there is an upstream PR extending save_pretrained() to support models loaded with device mapping. Hence, I think we can wait for that PR to be merged.

What I did eventually

To mitigate the "stuck" issue, I added a timeout to those predictions:

  1. Add a 60-second timeout to signature inference, with a fallback to the default signature for each pipeline type.
  2. Add a 60-second timeout to requirements inference, with a fallback to the default pip requirements.

This PR also includes a small refactoring: extracting the signature-related logic into a separate file.
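Conceptually, the change is "try prediction-based inference under a timeout, otherwise use a static default". A minimal, self-contained sketch of that pattern is shown below; it uses a plain TimeoutError and placeholder arguments rather than the exact helpers added in this PR.

import signal
from contextlib import contextmanager


@contextmanager
def run_with_timeout(seconds):
    """Raise TimeoutError if the wrapped block runs longer than `seconds` (Unix only, not thread-safe)."""
    def handler(signum, frame):
        raise TimeoutError(f"Operation timed out after {seconds} seconds")

    signal.signal(signal.SIGALRM, handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)  # always cancel the pending alarm


def infer_signature_or_default(predict_fn, input_example, default_signature, timeout=60):
    """Prediction-based inference with a static fallback when the model hangs."""
    try:
        with run_with_timeout(timeout):
            return predict_fn(input_example)  # stands in for real signature/requirements inference
    except TimeoutError:
        # Fall back to the per-pipeline-type default instead of blocking model logging.
        return default_signature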

How is this PR tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests

Tested serving via Docker container (with accelerate)

Screenshot 2024-02-08 at 11 44 25

Does this PR require documentation update?

  • No. You can skip the rest of this section.
  • Yes. I've updated:
    • Examples
    • API references
    • Instructions

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

Update the Transformers log_model() API to enforce a timeout for signature and requirements inference, with a static fallback, to avoid hanging when saving large models.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

github-actions bot commented Feb 6, 2024

Documentation preview for c92e64b will be available when this CircleCI job completes successfully.


@github-actions github-actions bot added area/models MLmodel format, model serialization/deserialization, flavors rn/bug-fix Mention under Bug Fixes in Changelogs. labels Feb 6, 2024
@daniellok-db (Collaborator):

overall it looks good to me, i think this is a great idea especially if the default signatures are good enough for typical use-cases!

i agree that we can probably wait for someone to raise a feature request if we want the timeout to be configurable, but i do think 60 seconds is a bit short based on my own usage of transformers (though i'm not sure if i've been doing things in an optimized way). maybe we can increase the timeout?

it looks like there are some test failures but happy to accept after those are resolved and if nobody else has any concerns!

@B-Step62 (Collaborator, Author) commented Feb 7, 2024

hmmm the failure of test_transformers_tf_model_log_without_conda_env_uses_default_env_with_expected_dependencies is so weird, I can't reproduce it with the same package versions.

tests/transformers/test_transformers_model_export.py::test_transformers_tf_model_log_without_conda_env_uses_default_env_with_expected_dependencies PASSED | MEM 1.6/61.8 GB | DISK 205.3/484.6 GB [100%]

#: Specifies the timeout for model inference with input example(s) when logging/saving a model.
#: MLflow runs a few inference requests against the model to infer model signature and pip
#: requirements. Sometimes the prediction hangs for a long time, especially for a large model.
#: This timeout avoid the hanging and fall back to the default signature and pip requirements.
Member (review comment):

Suggested change
#: This timeout avoid the hanging and fall back to the default signature and pip requirements.
#: This timeout limits the allowable time for performing a prediction for signature inference and will abort the prediction, falling back to the default signature and pip requirements.

Member:

How can we disable timeout? Setting this environment variable to 0?

@B-Step62 (Collaborator, Author):

we can set an empty string or an extremely large number like 99999
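For example (the variable name below is a hypothetical placeholder for illustration; check mlflow/environment_variables.py for the actual name introduced by this PR):

import os

# Hypothetical name used only for illustration -- see mlflow/environment_variables.py
# for the real environment variable added by this PR.
os.environ["MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT"] = ""       # empty string: no timeout
# or
os.environ["MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT"] = "99999"  # effectively unlimited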

Comment on lines 11 to 42
@contextmanager
def run_with_timeout(seconds):
    """
    Context manager to run a block of code with a timeout. If the block of code takes longer
    than `seconds` to execute, a `TimeoutError` is raised.

    NB: This function uses Unix signals to implement the timeout, so it is not thread-safe.
    Also it does not work on non-Unix platforms such as Windows.

    E.g.
    ```
    with run_with_timeout(5):
        model.predict(data)
    ```
    """
    assert _IS_UNIX, "Timeouts are not implemented yet for non-Unix platforms"

    def signal_handler(signum, frame):
        raise MLflowTimeoutError(f"Operation timed out after {seconds} seconds")

    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)

    try:
        yield
    finally:
        signal.alarm(0)  # Disable the alarm after the operation completes or times out
@harupy (Member):

/ai are there any risks when we forcibly terminate operations like this function does?

AI bot reply:

@harupy

Yes, there are some risks when we forcibly terminate operations using a function like this.

  1. Data Consistency: If the block of code being executed involves altering some shared resource or global state, forcing a termination might leave that in an inconsistent state.

  2. Resource Leaks: Sometimes, a block of code may need to clean up resources (like open files, network connections, etc.) at the end of its execution. If it's forcibly terminated, these cleanup steps might be skipped, leading to resource leakages.

  3. Unpredictability: Since the block being forcefully terminated can contain any code, there's a risk that the handling of the TimeoutError may not be as intended. The programmer might not have considered the case where an operation could be terminated midway.

  4. Thread Safety: As the function documentation states, this implementation relies on Unix signals, which are not thread-safe. If the context manager wraps code executing in multiple threads, the signal could be delivered to a thread other than the one that set the timer, leading to undefined behavior.

These and possibly more problems make the forcible termination of any operation a decision that should be considered carefully.


@B-Step62 (Collaborator, Author):

/ai Does this disable the signal.alarm when the user interrupts while waiting?

AI bot reply:

@B-Step62

The provided code snippet doesn't provide a complete context.

import signal

It's just importing the signal module in Python, which provides mechanisms to handle different types of signals, including the SIGALRM signal for setting an alarm, and the SIGINT signal, which is typically generated by the user interrupting a process.

If you want to disable the signal.alarm() when the user interrupts while waiting, you should catch the SIGINT signal and disable the alarm in the signal handling function.

However, since no additional code/context is given, I cannot provide a specific answer to your question. Please provide more context or the complete piece of code that you're working with.


@B-Step62 (Collaborator, Author):

/ai Does this code disable the signal.alarm when the user interrupts the operation running inside the context manager?

@contextmanager
def run_with_timeout(seconds):
    def signal_handler(signum, frame):
        raise MLflowTimeoutError(f"Operation timed out after {seconds} seconds")

    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)

    try:
        yield
    finally:
        signal.alarm(0)  # Disable the alarm after the operation completes or times out

AI bot reply:

@B-Step62

Yes, this code snippet does disable the alarm if the user interrupts the operation inside the context manager.

When the signal.alarm(0) function is executed in the finally block, it disables the alarm. The finally block is executed when the code within the try block either completes execution normally, or throws an exception due to user interruption or any other reason. So, even if a user interrupts the operation, the alarm would be disabled.

So in summary, regardless of how the operation within the context manager ends (even with an interruption), the alarm will be disabled.
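For example, a quick self-contained check of this behavior (a toy reproduction using a plain TimeoutError, not code from this PR) could look like:

import signal
import time


def _on_alarm(signum, frame):
    raise TimeoutError("timed out")


signal.signal(signal.SIGALRM, _on_alarm)
signal.alarm(10)          # arm a 10-second alarm, like run_with_timeout(10)
try:
    time.sleep(30)        # press Ctrl+C here to simulate a user interrupt
except (KeyboardInterrupt, TimeoutError):
    pass
finally:
    signal.alarm(0)       # the finally clause cancels the alarm even on interrupt

print(signal.alarm(0))    # prints 0: no alarm remains pending afterwards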


@B-Step62 (Collaborator, Author) commented Feb 7, 2024

Oookie, it appears that I finally resolved the failing test test_transformers_tf_model_log_without_conda_env_uses_default_env_with_expected_dependencies... Basically there are two layers of bugs: one caused the test failure in this PR, and the other has been hiding that bug, making the test pass until now even though it shouldn't.

Prerequisite Knowledge: How we capture dependencies for Transformer models

Requirement inference is done by monitoring import events during model loading (or model prediction if an input_example is provided). For Transformers specifically, we repeat this up to three times to determine whether the model uses TensorFlow, PyTorch, or both (context). This is done with a trial-and-error approach, roughly as follows (a minimal sketch appears after the list):

  1. Try loading the model with USE_TORCH = True set and validate that there is no tensorflow import => if this passes, the model depends only on PyTorch.
  2. Then try loading with USE_TF = True set and validate that there is no torch import => if this passes, the model depends only on TensorFlow.
  3. Finally, if both fail, record both TensorFlow and PyTorch as required.
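A minimal sketch of this trial-and-error flow (illustrative only; load_model_and_capture_imports is a hypothetical stand-in for MLflow's import-capturing subprocess):

def detect_frameworks(load_model_and_capture_imports):
    """Return the set of deep-learning frameworks a Transformers model needs.

    `load_model_and_capture_imports(env)` is assumed to load the model in a fresh
    subprocess with the given environment variables set and to return the set of
    top-level modules imported while doing so.
    """
    # 1. Force PyTorch; if TensorFlow never gets imported, torch alone is enough.
    imports = load_model_and_capture_imports({"USE_TORCH": "TRUE"})
    if "tensorflow" not in imports:
        return {"torch"}

    # 2. Force TensorFlow; if torch never gets imported, tensorflow alone is enough.
    imports = load_model_and_capture_imports({"USE_TF": "TRUE"})
    if "torch" not in imports:
        return {"tensorflow"}

    # 3. Otherwise, record both frameworks as requirements.
    return {"torch", "tensorflow"}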

This seems to work, but the import capturing is not as straightforward as it looks, which caused a few bugs.

Bug 1. The environment variables USE_TF / USE_TORCH are set after Transformers initializes the _torch_available flag.

Setting these environment variables is very important: they not only instruct Transformers to load the model in the specified framework, but also prevent it from importing the other library. For example, Transformers manages a boolean state _torch_available that acts as a switch for much of the logic requiring PyTorch. This flag is set to True when PyTorch is installed and USE_TF is not set to "True". Since our test environment installs both TensorFlow and PyTorch, the USE_TF env variable is necessary to override this flag to False.

However, the issue is that the flag is set only once, when the Transformers library is imported for the first time. Hence, it won't be flipped if we set the env var after that first import. At present we set the env var when starting the import capturing (code), but Transformers is actually imported earlier than that, so the flag is not set correctly.

Solution
To resolve the issue, the latest revision of this PR modifies the logic to set the environment variable when starting the subprocess for model loading (see the sketch below).
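A sketch of the idea, assuming for illustration that the capture script is invoked roughly like this (the real arguments of mlflow/utils/_capture_transformers_modules.py differ):

import os
import subprocess
import sys


def capture_imports_in_subprocess(model_path, framework_env):
    """Load the model in a fresh interpreter with USE_TF / USE_TORCH already set,
    so that transformers initializes _torch_available / _tf_available consistently
    with the requested framework (setting the env var after import is too late)."""
    env = os.environ.copy()
    env.update(framework_env)  # e.g. {"USE_TF": "TRUE"} or {"USE_TORCH": "TRUE"}
    # Hypothetical invocation for illustration; the real script takes different arguments.
    return subprocess.run(
        [sys.executable, "-m", "mlflow.utils._capture_transformers_modules", model_path],
        env=env,
        capture_output=True,
        text=True,
        check=False,
    )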

However, the question then is why this test had not failed before this PR, which relates to the next bug.

Bug 2. accelerate is captured as a dependency of the TensorFlow model and hides PyTorch in the logged requirements.

After MLflow captures all imported packages, it doesn't use the list as-is. Instead, it trims the list by removing packages that are installed by other packages anyway. For example, if the captured packages are ["scikit-learn", "numpy"], this pruning removes numpy because it is installed as part of scikit-learn anyway (a minimal sketch of this pruning is shown below).

What happened in this test before this PR is that accelerate was captured as a model dependency as well, and since it has torch as a core dependency, torch was pruned and not listed in the final requirements.txt. As a result, the assertion assert "torch" not in ... did not fail. However, the model requirements actually include PyTorch indirectly via accelerate.
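A minimal sketch of this pruning step (illustrative only; get_installed_dependencies is a hypothetical stand-in for the real package-metadata lookup in mlflow/utils/requirements_utils.py):

def prune_transitive_requirements(captured, get_installed_dependencies):
    """Drop packages that some other captured package already pulls in.

    `get_installed_dependencies(pkg)` is assumed to return the set of packages that
    `pkg` declares as install requirements (e.g. via importlib.metadata).
    """
    pruned = set(captured)
    for pkg in captured:
        # Anything required by another captured package is redundant in requirements.txt.
        pruned -= get_installed_dependencies(pkg) - {pkg}
    return sorted(pruned)


# Example: torch disappears because accelerate already depends on it.
deps = {"accelerate": {"torch", "numpy"}, "tensorflow": {"numpy"}, "torch": set(), "numpy": set()}
print(prune_transitive_requirements({"accelerate", "tensorflow", "torch"}, deps.get))
# -> ['accelerate', 'tensorflow']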

Solution
In the latest revision of this PR, accelerate is no longer captured as a dependency for the TensorFlow model. I couldn't pinpoint what triggered this change, but it is most likely similar internal state handling in Transformers. While this leaves a bit of ambiguity, the new behavior is correct because accelerate only supports PyTorch models, i.e. it should not be logged for a TensorFlow model. A PyTorch model still logs accelerate as a dependency when it is installed (validated by test_transformers_pt_model_save_without_conda_env_uses_default_env_with_expected_dependencies).

@BenWilson2 (Member) left a comment:

Great investigation and fixes! LGTM!
As a final round of checks, let's validate model serving of a small toy TF version of a pipeline and a Torch with accelerate version, just to ensure that the modifications to inferred requirements work seamlessly with inference container build logic (it should work just fine, let's just make sure)

"Attempted to generate a signature for the saved model or pipeline "
f"but encountered an error: {e}"
)
raise
Collaborator (review comment):

Should we raise the exception or just return None?

@B-Step62 (Collaborator, Author):

I think we should raise, cuz this case is highly likely a critical issue with model prediction that would cause the same failure after the model is loaded/served?

Member:

Good point. Can we attempt to use the fallback in the case of any failure that occurs and only raise if a signature cannot be generated at all?

  • remove the raise
  • modify the warning in line 125 to raise an MLflowException

The reason is that if a signature is not generated for these models on Databricks, they won't be eligible for registration in UC and won't be able to be submitted to model serving.

@B-Step62 (Collaborator, Author) commented Feb 8, 2024:

I see, but should we tolerate all errors during signature inference? For example, we raise MlflowException when the given model is not a Pipeline instance (L145). Also, whatever happens in this prediction will happen in production after serving, I guess, and solving an issue in model serving is kinda hard:

_TransformersWrapper(
    pipeline=pipeline, model_config=model_config, flavor_config=flavor_config
).predict(data, params=params)

What about at least letting those exceptions propagate so we fail fast, while falling back for any errors from our own code, i.e. the signature inference logic?

Collaborator:

I think if there are any errors during prediction result generation, the input example might be wrong (or the model has some problem), while the signature doesn't necessarily require an output schema.

@B-Step62 (Collaborator, Author):

the input example might be wrong (or the model has some problem)

Yeah, this is what I'm worried about; it's better to tell users "hey, something is wrong with your model or example". But I agree that the signature itself doesn't necessarily need the output, so such validation is probably beyond the responsibility of this function. Will update to fall back without throwing (which, I realize, is the same as what we do for requirements inference as well).

@B-Step62 (Collaborator, Author) commented Feb 8, 2024:

Reverted the change for the case where no default fallback signature is found. If we raise an exception in that case, it prevents customers from saving custom pipeline classes (and also caused failures in test cases like test_invalid_task_inference_raises_error). While this might not be ideal for the UC experience, I'm keeping the original behavior, i.e. just warn and return no signature, within the scope of this PR. I can do a follow-up if necessary.
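Put together, the behavior settled on in this thread looks roughly like the sketch below (illustrative helper names, not the actual MLflow code; the warning text mirrors the snippet quoted above):

import logging

_logger = logging.getLogger(__name__)


def infer_or_fallback_signature(pipeline, input_example, infer_fn, default_signatures):
    """Fall back to a static default on any inference error; if no default exists
    for the pipeline type, warn and return no signature instead of raising."""
    if input_example is not None:
        try:
            return infer_fn(pipeline, input_example)  # prediction-based inference (may time out)
        except Exception as e:
            _logger.warning(
                "Attempted to generate a signature for the saved model or pipeline "
                f"but encountered an error: {e}"
            )
    default = default_signatures.get(type(pipeline).__name__)
    if default is None:
        _logger.warning("A default signature could not be determined for this pipeline type.")
        return None
    return default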

@harupy (Member) left a comment:

LGTM!

@serena-ruan (Collaborator) left a comment:

LGTM!

@B-Step62 B-Step62 merged commit 8a72306 into mlflow:master Feb 8, 2024
61 checks passed
@B-Step62 B-Step62 deleted the fix-transformer-save-stuck branch February 8, 2024 10:58
sateeshmannar pushed a commit to StateFarmIns/mlflow that referenced this pull request Feb 20, 2024
…del logging. (mlflow#11037)

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Labels
area/models MLmodel format, model serialization/deserialization, flavors rn/bug-fix Mention under Bug Fixes in Changelogs.
5 participants