
Update Inference Engine checkpoint loading + meta tensor assertions #2940

Merged · 12 commits · May 10, 2023

Conversation

@lekurile (Contributor) commented Mar 3, 2023

This PR adds a model_device_meta attribute to the InferenceEngine that's used to:

  • Assert that replace_with_kernel_inject == True when meta tensors are used, since meta tensors are only supported with kernel injection enabled.
  • Allow the InferenceEngine to load checkpoints via _load_checkpoint() when a checkpoint is passed to the init_inference API, but only when meta tensors are not used.

This PR also adds an assertion in the initialize_tensors function of the base container to check that, if the model is using meta tensors, the corresponding model container supports the meta tensor feature.
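The container-side check described above might look roughly like this (a sketch, not DeepSpeed's actual code; BaseContainer, MetaContainer, and the is_meta_mode flag are hypothetical names):

```python
import torch

class BaseContainer:
    # Hypothetical base container; is_meta_mode marks containers that
    # support the meta tensor feature.
    is_meta_mode = False

    def initialize_tensors(self, model):
        # A model counts as meta-loaded if any parameter sits on the meta device.
        model_is_meta = any(p.is_meta for p in model.parameters())
        if model_is_meta:
            assert self.is_meta_mode, \
                "Model uses meta tensors but this container does not support them"

class MetaContainer(BaseContainer):
    is_meta_mode = True

meta_model = torch.nn.Linear(2, 2, device="meta")
MetaContainer().initialize_tensors(meta_model)  # passes: container supports meta
```

With a plain BaseContainer, the same call would trip the assertion, which is the failure mode the PR guards against.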

@lekurile lekurile marked this pull request as ready for review May 9, 2023 19:17
@lekurile lekurile changed the title Change init_inference checkpoint loading to explicitly check for meta tensor Update Inference Engine checkpoint loading + meta tensor assertions May 9, 2023
@@ -151,6 +151,13 @@ def __init__(self, model, config):
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
"If you want to use cuda graph, please upgrade torch to at least v1.10"

# Check if model passed to engine is loaded w/ meta tensors, in which case
# kernel injection must be enabled.
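The check the new comment introduces, together with the checkpoint gating from the description, might be implemented along these lines (a sketch with hypothetical names; configure_engine stands in for the engine's __init__ logic, and the print call stands in for _load_checkpoint()):

```python
import torch

def detect_meta_model(model):
    # A model counts as meta-loaded if any parameter sits on the meta device.
    return any(p.is_meta for p in model.parameters())

def configure_engine(model, replace_with_kernel_inject, checkpoint=None):
    model_device_meta = detect_meta_model(model)
    if model_device_meta:
        # Rule 1: meta tensors require kernel injection.
        assert replace_with_kernel_inject, \
            "Meta tensors are only supported when kernel injection is enabled"
    if checkpoint is not None and not model_device_meta:
        # Rule 2: only load the checkpoint here when meta tensors are not
        # used; stand-in for the engine's _load_checkpoint() path.
        print(f"loading checkpoint: {checkpoint}")
    return model_device_meta

meta_model = torch.nn.Linear(4, 4, device="meta")
regular_model = torch.nn.Linear(4, 4)

print(configure_engine(meta_model, replace_with_kernel_inject=True))   # True
print(configure_engine(regular_model, replace_with_kernel_inject=False,
                       checkpoint="model.pt"))  # hypothetical path; prints False
```

Passing the meta model with replace_with_kernel_inject=False would raise the assertion, matching the behavior the diff's comment describes.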
Contributor:
Please update the comment to add why this only works for HF models.

Contributor Author (@lekurile):
Added a note saying that the device type is sourced assuming a Hugging Face hierarchy.
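The device-type detection the note refers to might, under a Hugging Face-style module hierarchy, be sketched as follows (hypothetical helper; DeepSpeed's actual code may source the device differently):

```python
import torch

def get_model_device_type(model):
    # Sketch: assumes a Hugging Face-style hierarchy in which the first
    # parameter's device is representative of the whole model, so a
    # meta-loaded HF model reports "meta" here.
    return next(model.parameters()).device.type

print(get_model_device_type(torch.nn.Linear(2, 2)))                 # cpu
print(get_model_device_type(torch.nn.Linear(2, 2, device="meta")))  # meta
```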
