Add timeout for signature/requirement inference during Transformer model logging. #11037
Conversation
Overall it looks good to me. I think this is a great idea, especially if the default signatures are good enough for typical use cases! I agree that we should probably wait for someone to raise a feature request before making the timeout configurable, but I do think 60 seconds is a bit short based on my own usage of transformers (though I'm not sure I've been doing things in an optimized way). Maybe we can increase the timeout? It looks like there are some test failures, but I'm happy to accept after those are resolved and if nobody else has any concerns!
hmmm the failure of
mlflow/environment_variables.py
Outdated
#: Specifies the timeout for model inference with input example(s) when logging/saving a model.
#: MLflow runs a few inference requests against the model to infer model signature and pip
#: requirements. Sometimes the prediction hangs for a long time, especially for a large model.
#: This timeout avoid the hanging and fall back to the default signature and pip requirements.
Suggested change:
- #: This timeout avoid the hanging and fall back to the default signature and pip requirements.
+ #: This timeout limits the allowable time for performing a prediction for signature inference and will abort the prediction, falling back to the default signature and pip requirements.
How can we disable the timeout? By setting this environment variable to 0?
We can set an empty string, or an extremely large number like 99999.
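For illustration, a minimal sketch of how that convention could be parsed (the variable name and default below are placeholders, not necessarily what this PR introduces):

```python
import os

def get_inference_timeout(env_name="MLFLOW_EXAMPLE_INFERENCE_TIMEOUT", default=180):
    # Hypothetical parsing of the "empty string disables the timeout" convention.
    raw = os.environ.get(env_name)
    if raw is None:
        return default      # unset: use the default timeout
    if raw.strip() == "":
        return None         # empty string: no timeout at all
    return int(raw)         # otherwise: explicit timeout in seconds
```

A caller would then skip `run_with_timeout` entirely whenever this returns `None`.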
mlflow/utils/timeout.py
Outdated
@contextmanager
def run_with_timeout(seconds):
    """
    Context manager to run a block of code with a timeout. If the block of code takes longer
    than `seconds` to execute, a `TimeoutError` is raised.

    NB: This function uses Unix signals to implement the timeout, so it is not thread-safe.
    Also, it does not work on non-Unix platforms such as Windows.

    E.g.
    ```
    with run_with_timeout(5):
        model.predict(data)
    ```
    """
    assert _IS_UNIX, "Timeouts are not implemented yet for non-Unix platforms"

    def signal_handler(signum, frame):
        raise MLflowTimeoutError(f"Operation timed out after {seconds} seconds")

    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)

    try:
        yield
    finally:
        signal.alarm(0)  # Disable the alarm after the operation completes or times out
/ai are there any risks when we forcibly terminate operations like this function does?
Yes, there are some risks when we forcibly terminate operations using a function like this.

- Data consistency: If the block of code being executed alters some shared resource or global state, forcing a termination might leave it in an inconsistent state.
- Resource leaks: Sometimes a block of code needs to clean up resources (open files, network connections, etc.) at the end of its execution. If it's forcibly terminated, these cleanup steps might be skipped, leading to resource leaks.
- Unpredictability: Since the block being forcibly terminated can contain any code, there's a risk that the handling of the TimeoutError may not be as intended. The programmer might not have considered the case where an operation could be terminated midway.
- Thread safety: As the function documentation states, this implementation relies on Unix signals, which are not thread-safe. If the context manager wraps code executing in multiple threads, the signal could be delivered to a thread other than the one that set the timer, leading to undefined behavior.

These and possibly more problems make the forcible termination of any operation a decision that should be considered carefully.
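To make the resource-leak and data-consistency points concrete, here is a small sketch assuming the `run_with_timeout` context manager and `MLflowTimeoutError` from this PR: because the timeout surfaces as an ordinary Python exception rather than a hard kill, inner `with` blocks and `finally` clauses still run, but any in-progress work is abandoned wherever the signal fires.

```python
try:
    with run_with_timeout(1):  # context manager from this PR (Unix only)
        with open("/tmp/scratch.bin", "wb") as f:  # the file IS closed on timeout
            while True:
                f.write(b"x" * 1024)  # ...but a partial file remains on disk
except MLflowTimeoutError:
    pass  # whatever the block was mutating may now be in an inconsistent state
```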
/ai Does this disable the signal.alarm when the user interrupts while waiting?
The provided code snippet doesn't provide complete context.

import signal

It's just importing the `signal` module in Python, which provides mechanisms to handle different types of signals, including the SIGALRM signal for setting an alarm and the SIGINT signal, which is typically generated by the user interrupting a process.

If you want to disable `signal.alarm()` when the user interrupts while waiting, you should catch the SIGINT signal and disable the alarm in the signal handling function.

However, since no additional code/context is given, I cannot provide a specific answer to your question. Please provide more context or the complete piece of code that you're working with.
/ai Does this code disable the signal.alarm when the user interrupts the operation inside the context manager while waiting?
@contextmanager
def run_with_timeout(seconds):
    def signal_handler(signum, frame):
        raise MLflowTimeoutError(f"Operation timed out after {seconds} seconds")

    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)  # Disable the alarm after the operation completes or times out
Yes, this code snippet does disable the alarm if the user interrupts the operation inside the context manager.

When `signal.alarm(0)` is executed in the `finally` block, it disables the alarm. The `finally` block is executed when the code within the `try` block either completes execution normally or throws an exception, due to user interruption or any other reason. So, even if a user interrupts the operation, the alarm will be disabled.

In summary, regardless of how the operation within the context manager ends (even with an interruption), the alarm will be disabled.
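A quick, self-contained way to check this (assuming the `run_with_timeout` shown above; Unix only, since it relies on SIGALRM):

```python
import signal

try:
    with run_with_timeout(5):
        raise KeyboardInterrupt  # simulate the user pressing Ctrl-C mid-operation
except KeyboardInterrupt:
    # The finally block already ran, so no alarm remains scheduled:
    assert signal.alarm(0) == 0
```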
Oookie, it appears that I finally resolved the failed test.

Prerequisite Knowledge: How we capture dependencies for Transformer models

Requirement inference is done by monitoring the modules that get imported while running a prediction. This seems to work, but the import capturing is not as straightforward as it looks and caused a few bugs.

Bug 1. Environment variable
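For readers unfamiliar with the mechanism, a simplified sketch of what import-based capture generally looks like (an assumption about the approach's shape, not MLflow's exact implementation):

```python
import sys

def capture_imports(predict_fn, example):
    """Return top-level packages newly imported while running a prediction."""
    before = set(sys.modules)
    predict_fn(example)  # imports triggered lazily by prediction show up here
    return {name.split(".")[0] for name in set(sys.modules) - before}

# Usage with any callable, e.g. a transformers pipeline:
# capture_imports(pipeline, "An input example")
```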
Great investigation and fixes! LGTM!
As a final round of checks, let's validate model serving with a small toy TF version of a pipeline and a Torch-with-accelerate version, just to ensure that the modifications to inferred requirements work seamlessly with the inference container build logic (it should work just fine; let's just make sure).
mlflow/transformers/signature.py
Outdated
"Attempted to generate a signature for the saved model or pipeline " | ||
f"but encountered an error: {e}" | ||
) | ||
raise |
Should we raise the exception or just return None?
I think we should raise, since this case is highly likely a critical issue with the model prediction that would cause the same problem after the model is loaded/served?
Good point. Can we attempt to use the fallback in the case of any failure that occurs, and only raise if a signature cannot be generated at all?

- remove the raise
- modify the warning in line 125 to raise an MLflowException

The reason is that if a signature is not generated for these models on Databricks, they won't be eligible for registration in UC and won't be able to be submitted to model serving.
I see, but should we allow all errors during signature inference? For example, we raise MlflowException when the given model is not a Pipeline instance (L145). Also, whatever happens in this prediction will happen in production after serving, I guess, and solving an issue in model serving is kinda hard.

_TransformersWrapper(
    pipeline=pipeline, model_config=model_config, flavor_config=flavor_config
).predict(data, params=params)

What about blocking on those exceptions at least, to fail fast, while allowing fallback for any errors from our code, i.e. the signature inference logic?
I think if there are any errors during prediction result generation, the input example might be wrong (or the model has some problem), while the signature doesn't necessarily require an output schema.
> the input example might be wrong (or the model has some problem)

Yeah, this is what I'm worried about; it's better to tell users "hey, something is wrong with your model or example". But I agree that the signature itself doesn't necessarily need the output, so such validation is probably beyond the responsibility of this function. I will update it to fall back without throwing (which, I realize, is the same as what we do for requirement inference).
Reverted the change for the case where no default fallback signature is found. If we raise an exception in that case, it prevents customers from saving custom pipeline classes (and also caused failures for test cases like test_invalid_task_inference_raises_error). While it might not be ideal for the UC experience, I'll keep the original behavior, i.e. just warn and return no signature, within the scope of this PR. I can do a follow-up if necessary.
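In other words, the resulting behavior is roughly this sketch (the helper name `infer_signature_from_prediction` is illustrative, not an actual function in the codebase):

```python
import logging

_logger = logging.getLogger(__name__)

def infer_or_fallback(pipeline, example, default_signature=None):
    try:
        # hypothetical helper wrapping prediction-based signature inference
        return infer_signature_from_prediction(pipeline, example)
    except Exception as e:
        _logger.warning(
            "Attempted to generate a signature for the saved model or pipeline "
            "but encountered an error: %s", e,
        )
    if default_signature is None:
        # no exception here: just warn and log the model without a signature
        _logger.warning("No default signature is available for this pipeline type.")
    return default_signature
```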
LGTM!
LGTM!
Add timeout for signature/requirement inference during Transformer model logging. (mlflow#11037) Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
What changes are proposed in this pull request?
Problem
Transformers model saving involves a few model predictions if an input example is provided. However, this can take significantly long for huge models like LLMs, because we don't support saving models with optimized device mapping, i.e. distributed across multiple devices. (This is a limitation of the Transformers save_pretrained method we use in the logging context.) As a result, users can only save models on a single CPU/GPU, which can take hours or more for huge models.

What didn't work

Initially, I tried to solve this problem by allowing saving the model with device_mapping, porting the model to a single CPU/GPU just before calling save_pretrained. If this were possible, we could run prediction for signature/requirement inference with the optimized device setting, which would make the latency acceptably short. However, it turns out that moving a model to a different device is challenging when it is originally distributed across multiple devices. For example, running model.to(torch.device("cpu")) will raise RuntimeError: You can't move a model that has some modules offloaded to cpu or disk.

We could implement device handling using native torch/tf libraries, but I found that there is a PR for extending save_pretrained() to support models loaded with device mapping. Hence, I think we can wait for that PR to be merged.

What I did eventually

To mitigate the "stuck" issue, I just added a timeout to those predictions.

This PR also includes a small refactoring: extracting signature-related logic to a separate file.
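Conceptually, the inference step is now wrapped like the following sketch (assuming the run_with_timeout context manager added in this PR; the other names are illustrative):

```python
try:
    with run_with_timeout(timeout_seconds):  # e.g. the value read from the new env var
        # hypothetical helper running a prediction to infer the signature
        signature = infer_signature_from_prediction(pipeline, input_example)
except MLflowTimeoutError:
    # static fallback based on the pipeline's task type (illustrative name)
    signature = get_default_signature_for_task(pipeline.task)
```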
How is this PR tested?
Tested serving via Docker container (with accelerate)
Does this PR require documentation update?
Release Notes
Is this a user-facing change?
Update the Transformers log_model() API to enforce a timeout for signature and requirement inference, with a static fallback, to avoid hanging when saving large models.
What component(s), interfaces, languages, and integrations does this PR affect?
Components
- area/artifacts: Artifact stores and artifact logging
- area/build: Build and test infrastructure for MLflow
- area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
- area/docs: MLflow documentation pages
- area/examples: Example code
- area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- area/models: MLmodel format, model serialization/deserialization, flavors
- area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
- area/projects: MLproject format, project running backends
- area/scoring: MLflow Model server, model deployment tools, Spark UDFs
- area/server-infra: MLflow Tracking server backend
- area/tracking: Tracking Service, tracking client APIs, autologging

Interface

- area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- area/windows: Windows support

Language

- language/r: R APIs and clients
- language/java: Java APIs and clients
- language/new: Proposals for new client languages

Integrations

- integrations/azure: Azure and Azure ML integrations
- integrations/sagemaker: SageMaker integrations
- integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

- rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
- rn/feature - A new user-facing feature worth mentioning in the release notes
- rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
- rn/documentation - A user-facing documentation change worth mentioning in the release notes