Update Inference Engine checkpoint loading + meta tensor assertions #2940
Conversation
@@ -151,6 +151,13 @@ def __init__(self, model, config):
            assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
                "If you want to use cuda graph, please upgrade torch to at least v1.10"

        # Check if model passed to engine is loaded w/ meta tensors, in which case
        # kernel injection must be enabled.
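For context, the check described by the new comment can be illustrated roughly as follows. This is a minimal sketch, not the actual DeepSpeed code: the helper name and the standalone `replace_with_kernel_inject` flag are illustrative, and the device lookup assumes all of the model's parameters live on one device.

```python
import torch.nn as nn


def model_is_meta(model: nn.Module) -> bool:
    """Return True if the model's parameters were created on the 'meta' device."""
    first_param = next(model.parameters(), None)
    return first_param is not None and first_param.device.type == "meta"


# A layer created with meta tensors (no real storage is allocated).
meta_model = nn.Linear(16, 16, device="meta")

# Illustrative stand-in for the engine's kernel-injection config flag.
replace_with_kernel_inject = True

if model_is_meta(meta_model):
    assert replace_with_kernel_inject, (
        "Models loaded with meta tensors are only supported "
        "when kernel injection is enabled"
    )
```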
Please update the comment to add why this only works for HF models.
Added a note saying that the device type is sourced assuming a Hugging Face hierarchy.
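As a brief aside (not code from this PR) on why that caveat matters: Hugging Face's `PreTrainedModel` exposes a `device` property, while a plain `torch.nn.Module` does not, so reading the device type directly off the model object only works for HF-style models.

```python
import torch.nn as nn
from transformers import PreTrainedModel  # assumes transformers is installed

plain_module = nn.Linear(4, 4)
print(hasattr(plain_module, "device"))     # False: plain nn.Module has no .device
print(hasattr(PreTrainedModel, "device"))  # True: HF models expose a device property
```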
This PR adds a `model_device_meta` attribute to the `InferenceEngine` that's used to:

- Assert `replace_with_kernel_inject == True` if meta tensors are used, since we only support meta tensors when kernel injection is enabled.
- Allow the `InferenceEngine` to load checkpoints via `_load_checkpoint()` when a checkpoint is passed to the `init_inference` API, but only when meta tensors are not used.

This PR also adds an assertion in the `initialize_tensors` function of the base container to check that, if the model is using meta tensors, the corresponding model container uses the meta tensor feature.
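A rough sketch of what such an assertion could look like. The class and attribute names here (`BaseContainerSketch`, `is_meta_mode`, `client_module`) are illustrative stand-ins under the assumptions above, not the actual DeepSpeed container API.

```python
import torch.nn as nn


class BaseContainerSketch:
    # Containers that implement the meta tensor path would flip this to True.
    is_meta_mode = False

    def __init__(self, client_module: nn.Module):
        self.client_module = client_module

    def initialize_tensors(self):
        uses_meta = any(
            p.device.type == "meta" for p in self.client_module.parameters()
        )
        assert not uses_meta or self.is_meta_mode, (
            "Model was loaded with meta tensors, but this container does not "
            "support the meta tensor feature"
        )
        # ... actual tensor initialization would follow here ...


container = BaseContainerSketch(nn.Linear(8, 8, device="meta"))
# container.initialize_tensors()  # would raise: meta path not enabled for this container
```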