Saving and Loading pytorch model as state dict #3705
@shrinath-suresh Thanks for the updates! btw I have a question about what you mentioned regarding the saved model size.
I wrote a simple script to verify this behavior.

    import os
    import shutil
    import subprocess

    import torch
    from torchvision import models

    SAVE_DIR = "foo"

    # Start from an empty output directory.
    if os.path.exists(SAVE_DIR):
        shutil.rmtree(SAVE_DIR)
    os.makedirs(SAVE_DIR)

    model = models.resnet50(pretrained=True)
    # Save the entire model and its state dict side by side to compare file sizes.
    torch.save(model, f"{SAVE_DIR}/model.pt")
    torch.save(model.state_dict(), f"{SAVE_DIR}/state_dict.pt")
    print(subprocess.check_output(["ls", "-lh", SAVE_DIR]).decode("utf-8"))

output
The difference between |
WIP (I'm writing this to consider the full design space in order to make the right decision on the API.)
How should we support state dicts?
Option 1:

    log_model(model, "model")  # -> saves model.pt
    log_state_dict(model.state_dict(), "model")  # -> saves state_dict.pt (doesn't create or update an MLmodel file)
    # --- output ---
    # - model
    #   - model.pt
    #   - state_dict.pt
    #   - MLmodel

Workflow to load the model in the TorchServe plugin:

    path = _download_artifact_from_uri(model_uri)
    if contains_state_dict(path):
        serialized_file = os.path.join(path, mlflow.pytorch.STATE_DICT_FILENAME)
    else:
        serialized_file = os.path.join(path, mlflow.pytorch.MODEL_FILENAME)

Pros:
Cons:
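The Option 1 loading workflow can be sketched with the standard library alone. This is only an illustration, not the plugin's code: `contains_state_dict` is the hypothetical helper named in the comment, `pick_serialized_file` and `demo` are made-up names, and the two file-name constants are assumptions standing in for the proposed `mlflow.pytorch.STATE_DICT_FILENAME` and `mlflow.pytorch.MODEL_FILENAME`.

```python
import os
import tempfile

# Assumed file names taken from the Option 1 layout; the constants proposed
# in this thread may end up with different names or values.
STATE_DICT_FILENAME = "state_dict.pt"
MODEL_FILENAME = "model.pt"


def contains_state_dict(path):
    """Return True if the artifact directory holds a state dict file."""
    return os.path.isfile(os.path.join(path, STATE_DICT_FILENAME))


def pick_serialized_file(path):
    """Prefer the state dict file when present, else fall back to the full model."""
    if contains_state_dict(path):
        return os.path.join(path, STATE_DICT_FILENAME)
    return os.path.join(path, MODEL_FILENAME)


def demo():
    """Build two throwaway artifact directories and report which file is picked."""
    picked = []
    with tempfile.TemporaryDirectory() as root:
        layouts = [
            ("full", [MODEL_FILENAME]),
            ("both", [MODEL_FILENAME, STATE_DICT_FILENAME]),
        ]
        for name, files in layouts:
            artifact_dir = os.path.join(root, name)
            os.makedirs(artifact_dir)
            for filename in files:
                # Empty placeholder files are enough for the existence check.
                open(os.path.join(artifact_dir, filename), "wb").close()
            picked.append(os.path.basename(pick_serialized_file(artifact_dir)))
    return picked


print(demo())
```

Note that the decision depends only on which files exist in the directory, so nothing in the MLmodel file records which format was logged.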
Option 2:
Add a new flag argument
Pros:
Cons:
Option 3 (preferred):
Workflow to load the model in the TorchServe plugin:

    path = _download_artifact_from_uri(model_uri)
    config = model = Model.load(os.path.join(path, "MLmodel"))
    if "pytorch_state_dict" in config.flavors:
        serialized_file = os.path.join(path, mlflow.pytorch.STATE_DICT_FILENAME)
    else:
        serialized_file = os.path.join(path, mlflow.pytorch.MODEL_FILENAME)

Pros:
Cons:
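For comparison, the Option 3 check keys off the flavors recorded in the MLmodel file rather than off the files on disk. A minimal sketch, assuming the `flavors` mapping has already been parsed into a dict; the flavor key comes from the snippet in this comment, while the file names and the `pick_serialized_file` helper are illustrative assumptions, not confirmed MLflow API.

```python
import os

# Assumed names: the flavor key is the one proposed in Option 3, and the file
# names mirror the Option 1 layout. None of them are confirmed MLflow API.
STATE_DICT_FLAVOR = "pytorch_state_dict"
STATE_DICT_FILENAME = "state_dict.pt"
MODEL_FILENAME = "model.pt"


def pick_serialized_file(path, flavors):
    """Choose the file to load based on the flavors listed in the MLmodel file.

    `flavors` is the already-parsed `flavors` mapping of an MLmodel file.
    """
    if STATE_DICT_FLAVOR in flavors:
        return os.path.join(path, STATE_DICT_FILENAME)
    return os.path.join(path, MODEL_FILENAME)


# Two hand-written flavor mappings standing in for parsed MLmodel files.
full_model_flavors = {"pytorch": {"model_data": "data"}}
state_dict_flavors = {"pytorch_state_dict": {}}

print(pick_serialized_file("artifacts", full_model_flavors))
print(pick_serialized_file("artifacts", state_dict_flavors))
```

With this approach the serving side never has to probe the directory; the trade-off is that a state dict must be logged together with an MLmodel file that declares the extra flavor.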
Questions:
APPENDIX
What do we need to reconstruct a state dict model?
How does TorchServe reconstruct a state dict model from
|
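On the first appendix question: a state dict holds only parameter and buffer tensors keyed by name, so rebuilding a model also requires code that constructs the same architecture. The following is a minimal sketch in plain PyTorch; `build_model` is a hypothetical factory and does not describe how MLflow or TorchServe locate the model definition.

```python
import io

import torch
from torch import nn


def build_model():
    """Hypothetical factory: the architecture must be available as code."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


# Serialize only the state dict (tensors keyed by parameter name).
trained = build_model()
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)

# Rebuild: instantiate the architecture first, then load the tensors into it.
buffer.seek(0)
restored = build_model()
restored.load_state_dict(torch.load(buffer))
restored.eval()

# Compare the two models on the same input.
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(trained(x), restored(x))
print(same)
```

Without `build_model` (or an equivalent class definition) the saved tensors alone cannot be turned back into a callable model.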
My observation is from the MNIST example. I ran 10 epochs, and here is the result for the full model and the state dict:
|
@shrinath-suresh This is probably because |
You are right. The same MNIST example with PyTorch shows the same size for both the state dict and the entire model. We can take this discussion to a separate thread, as this PR has no dependency with
@harupy Do you have any more comments on the code?
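A quick way to check this kind of size claim without writing files is to serialize into in-memory buffers. The sketch below uses a small made-up `nn.Sequential` model, not the MNIST example from this thread, so the numbers it prints say nothing about the results reported above; `serialized_size` is a hypothetical helper.

```python
import io

import torch
from torch import nn


def serialized_size(obj):
    """Return the number of bytes torch.save writes for `obj`."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return buffer.getbuffer().nbytes


# Small stand-in model; any module built from standard layers would do.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))

full_size = serialized_size(model)
state_dict_size = serialized_size(model.state_dict())
# Raw bytes occupied by the parameter tensors themselves.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

print(f"full model: {full_size} B, state dict: {state_dict_size} B, parameters: {param_bytes} B")
```

For a plain module like this one, both files are dominated by the same parameter tensors, so the two sizes should be close. A larger gap would suggest that the pickled model object carries extra attributes beyond its parameters.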
mlflow/pytorch/__init__.py
Outdated
    with open(pickle_module_path, "w") as f:
        f.write(pickle_module.__name__)

    model_path = os.path.join(model_data_path, _SERIALIZED_TORCH_MODEL_FILE_NAME)
A user might log a state_dict that represents a checkpoint for inference and/or resuming training (this use case). In this case, _SERIALIZED_TORCH_MODEL_FILE_NAME (= "model.pth") doesn't seem to be the right name because it's not a model.
Maybe state_dict.pth is better?
Pro: easier to tell it's a state dict.
Con: harder to tell what the state dict represents.
renamed it to state_dict.pth
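The checkpoint use case raised in this review thread can be illustrated in plain PyTorch: a checkpoint is an ordinary dict that nests several state dicts, which is why a generic file name fits better than one suggesting a model. This is a sketch under assumed names (the keys `"epoch"`, `"model"`, and `"optimizer"` are illustrative); only the file name `state_dict.pth` comes from the discussion.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one optimization step so the optimizer has momentum state worth saving.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

# A checkpoint is a plain dict that nests several state dicts.
checkpoint = {
    "epoch": 1,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    # File name as agreed in the review comments.
    checkpoint_path = os.path.join(tmp_dir, "state_dict.pth")
    torch.save(checkpoint, checkpoint_path)
    loaded = torch.load(checkpoint_path)

# Resume: rebuild the objects, then restore their states.
resumed_model = nn.Linear(4, 2)
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.1, momentum=0.9)
resumed_model.load_state_dict(loaded["model"])
resumed_optimizer.load_state_dict(loaded["optimizer"])
print(loaded["epoch"])
```

As the "Con" in the comment points out, nothing in the file name tells a reader that this particular file also contains optimizer state.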
@shrinath-suresh I have pushed some commits to clean up the code :)
@harupy Thank you very much. The changes LGTM. Do you have any other comments on this PR? If not, can we merge it?
@shrinath-suresh LGTM! Thanks for all the hard work 👍
* Adding save_state_dict and load_state_dict method to mlflow.pytorch library Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing unwanted changes Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Resetting empty lines Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Adding Unit tests for save_state_dict and load_state_dict model Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Adding log_state_dict method and refactored load_model to reuse most of the code in load_state_dict method Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing unused argument Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Applying black Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* save_state_dict, log_state_dict and load_state_dict with pytorch flavor Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing MLModel file for state dict and adding appropriate conditions to load the state dict Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Updating doc strings Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Setting experimental annotation and saving state dict as state_dict.pth Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Fixing doc strings Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing state_dict key from save_model Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Addressing review comments Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Applying black Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing doc string Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* swapping arguments Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Using get_artifact_uri to derive model path Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing pickle_module from save and log state dict Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* rephrasing doc strings Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Renaming tests Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Comparing state dicts in test Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Disabling reimport error Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing blank line between params in doc string Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing model Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Replacing _get_model_artifact_path with _download_artifact_from_uri Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* creating get_sequential_model utility and renamving model_class to model Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing pd.DataFrame type conversion Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Adding compare state dicts utility Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing Ordered Dictionary from doc string Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Fixing Docstring Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing unused variable Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing unused import Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Addressing review comments Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing unrelated change Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Addressing review comments Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Removing data folder Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* Addressing review comments Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
* revert changes on load_model Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* remove redundant folder generation Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* Set exist_ok to True Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* Assert state_dict is dict Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* wording fix Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* kwargs Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* remove redundant model.eval Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* fix Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* Prevent false positive Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* test for nested_state_dict Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* blank line Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* move tests Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* put state dict functions in one place Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* remove unused variable Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* comment on test_save_state_dict_can_save_nested_state_dict Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* Fix Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* ensure model and optim can load state dict Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* enhance comment Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* comment Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* dot Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* remove useless comma Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* use pos args Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* rename Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* nit Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* article Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* example Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* Add checkpoint example Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* remove ... Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
* warning Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Co-authored-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: Shrinath Suresh shrinath@ideas2it.com
What changes are proposed in this pull request?
The current implementation of mlflow.pytorch only supports saving the entire model into MLflow. This PR adds support for saving and loading the model using a state dict. When the model's state dict is saved instead of the entire model, the size of the saved model can be reduced, which would be helpful during deployment of the model.
#3408 - Please read through the discussion points on that PR. It would be helpful for the future use cases mentioned above.
Implementation Details:
This PR adds two new methods to mlflow.pytorch, load_state_dict and save_state_dict, for loading and saving PyTorch models. It also adds a key, state_dict, under the pytorch flavor. By default (for an entire model), the key is set to false. Only when the model is saved or logged as a state dict is the key set to true.
Sample screenshot given below
How is this patch tested?
Tested by saving and loading the model both as a state dict and as an entire model. Unit tests are in progress.
Release Notes
Is this a user-facing change?
(Details in 1-2 sentences. You can just refer to another PR with a description if this PR is part of a larger change.)
What component(s), interfaces, languages, and integrations does this PR affect?
Components
area/artifacts: Artifact stores and artifact logging
area/build: Build and test infrastructure for MLflow
area/docs: MLflow documentation pages
area/examples: Example code
area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
area/models: MLmodel format, model serialization/deserialization, flavors
area/projects: MLproject format, project running backends
area/scoring: Local serving, model deployment tools, spark UDFs
area/server-infra: MLflow server, JavaScript dev server
area/tracking: Tracking Service, tracking client APIs, autologging
Interface
area/uiux: Front-end, user experience, JavaScript, plotting
area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
area/windows: Windows support
Language
language/r: R APIs and clients
language/java: Java APIs and clients
language/new: Proposals for new client languages
Integrations
integrations/azure: Azure and Azure ML integrations
integrations/sagemaker: SageMaker integrations
integrations/databricks: Databricks integrations
How should the PR be classified in the release notes? Choose one:
rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
rn/feature - A new user-facing feature worth mentioning in the release notes
rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
rn/documentation - A user-facing documentation change worth mentioning in the release notes