Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saving and Loading pytorch model as state dict #3705

Merged
merged 70 commits into from
Jan 12, 2021

Conversation

shrinath-suresh
Copy link
Contributor

Signed-off-by: Shrinath Suresh shrinath@ideas2it.com

What changes are proposed in this pull request?

The current implementation of mlflow.pytorch only supports for saving the entire model into mlflow. Adding support for saving and loading the model using state dict.

Instead of storing the entire model into mlflow, when the model state dicts are saved, the size of the model is reduced to a greater extent - which would be helpful during the deployment of the model.

#3408 - Please read through the discussion points on the PR . It would be helpful for the future use cases as mentioned above.

Implementation Details:

Adding two new methods to mlflow.pytorch - load_state_dict and save_state_dict for loading and saving the pytorch models. And also added a key state_dict under pytorch:flavor. By default(for entire model) the key will be set to false . Only when the model is saved/logged as state dict, the key would be set to true.

Sample screenshot given below

image

How is this patch tested?

Tested by saving/loading the model as both state dict and entire version. Working on the Unit tests.

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

(Details in 1-2 sentences. You can just refer to another PR with a description if this PR is part of a larger change.)

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/projects: MLproject format, project running backends
  • area/scoring: Local serving, model deployment tools, spark UDFs
  • area/server-infra: MLflow server, JavaScript dev server
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, JavaScript, plotting
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

…ibrary

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
@github-actions github-actions bot added area/models MLmodel format, model serialization/deserialization, flavors rn/feature Mention under Features in Changelogs. labels Nov 17, 2020
Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
mlflow/pytorch/__init__.py Outdated Show resolved Hide resolved
mlflow/pytorch/__init__.py Outdated Show resolved Hide resolved
mlflow/pytorch/__init__.py Outdated Show resolved Hide resolved
…of the code in load_state_dict method

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
@harupy
Copy link
Member

harupy commented Nov 29, 2020

@shrinath-suresh Thanks for the updates! btw I have a question about whta you mentioned about saved model size.

Instead of storing the entire model into mlflow, when the model state dicts are saved, the size of the model is reduced to a greater extent - which would be helpful during the deployment of the model.

I wrote a simple script to verify this behavior.

import os
import torch
from torchvision import models
import shutil
import subprocess


SAVE_DIR = "foo"

if os.path.exists(SAVE_DIR):
    shutil.rmtree(SAVE_DIR)

os.makedirs(SAVE_DIR)


model = models.resnet50(pretrained=True)
torch.save(model, f"{SAVE_DIR}/model.pt")
torch.save(model.state_dict(), f"{SAVE_DIR}/state_dict.pt")

print(subprocess.check_output(["ls", "-lh", SAVE_DIR]).decode("utf-8"))

output

total 400624
-rw-r--r--  1 harutakakawamura  staff    98M Nov 29 21:56 model.pt
-rw-r--r--  1 harutakakawamura  staff    98M Nov 29 21:56 state_dict.pt
                                         ^^^
                                         almost equal

The difference between torch.save(model.state_dict(), ...) and torch.save(model, ...) is very small. Am I missing something?

@harupy
Copy link
Member

harupy commented Nov 29, 2020

WIP

(I'm writing this to consider the full design space in order to make the right decision on the API.)

How should we support state dicts?

Option 1:

log_state_dict logs a state dict in the specified artifact path (similar to log_artifact).

log_model(model, "model")  # -> saves model.pt
log_state_dict(model.state_dict(), "model")  # -> saves state_dict.pt (doesn't cerate or update an MLmodel file)

# --- output ---
# - model
#   - model.pt
#   - state_dict.pt
#   - MLmodel

Workflow to load the model in the TorchServe plugin:

path = _download_artifact_from_uri(model_uri)

if cotains_state_dict(path):
    serialized_file = os.path.join(path, mlflow.pytorch.STATE_DICT_FILENAME)
else:
    serialized_file = os.path.join(path, mlflow.pytorch.MODEL_FILENAME)

Pros:

  • log_state_dict doesn't need to create an MLmodel file.
  • Simpler implementation of log_state_dict

Cons:

  • Need to call both log_model and log_state_dict, which seems weird.
  • We need to handle two cases:
    1. case where log_state_dict is called
    2. case where log_state_dict is not called

Option 2:

Add a new flag argument save_state_dict (default: False) to mlflow.pytorch.log_model. If this value is set to True, log the state dict along with the pickled model.

Pros

  • ???

Cons

Option 3 (preferred):

log_state_dict(model.state_dict(), "model") logs a state dict and creates an MLmodel file with a new pytorch_state_dict flavor.

Workflow to load the model in the TorchServe plugin:

path = _download_artifact_from_uri(model_uri)
config = model = Model.load(os.path.join(path, "MLmodel"))

if "pytorch_state_dict" in config.flavors:
    serialized_file = os.path.join(path, mlflow.pytorch.STATE_DICT_FILENAME)
else:
    serialized_file = os.path.join(path, mlflow.pytorch.MODEL_FILENAME)

Pros:

  • Can only log a state dict.

Cons:

  • To allow serving the model, we need to log the model class and constructor parameters along with the state dict.
  • More maintenance burden for us (we need to maintain both log_model and log_state_dict)

Questions:

  • Should we allow serving? -> Ideally yes
  • Can we start without serving support? -> yes

APPENDIX

What do we need to recontruct a state dict model?

  • state dict
  • model class (= a python file that define the class)
  • constructor parameters (if the model class requires them)

How does TorchServe reconstruct a state dict model from model_file?

Does the torchserve plugin require an MLmodel file?

yes

What is MLflow Model?

Each MLflow Model is a directory containing arbitrary files, together with an MLmodel file in the root of the directory that can define multiple flavors that the model can be viewed in.

What is flavor?

Flavors are the key concept that makes MLflow Models powerful: they are a convention that deployment tools can use to understand the model, which makes it possible to write tools that work with models from any ML library without having to integrate each tool with each library.

https://www.mlflow.org/docs/latest/models.html#storage-format

Should we create a new flavor for log_state_dict rather than using the exisiting pytorch flavor?

Yes to make it easier for downstream tools (e.g. the TorchServer plugin) to understand what they can do with the model.

@shrinath-suresh Just feel free to add comments

@shrinath-suresh
Copy link
Contributor Author

@shrinath-suresh Thanks for the updates! btw I have a question about whta you mentioned about saved model size.

Instead of storing the entire model into mlflow, when the model state dicts are saved, the size of the model is reduced to a greater extent - which would be helpful during the deployment of the model.

I wrote a simple script to verify this behavior.

import os
import torch
from torchvision import models
import shutil
import subprocess


SAVE_DIR = "foo"

if os.path.exists(SAVE_DIR):
    shutil.rmtree(SAVE_DIR)

os.makedirs(SAVE_DIR)


model = models.resnet50(pretrained=True)
torch.save(model, f"{SAVE_DIR}/model.pt")
torch.save(model.state_dict(), f"{SAVE_DIR}/state_dict.pt")

print(subprocess.check_output(["ls", "-lh", SAVE_DIR]).decode("utf-8"))

output

total 400624
-rw-r--r--  1 harutakakawamura  staff    98M Nov 29 21:56 model.pt
-rw-r--r--  1 harutakakawamura  staff    98M Nov 29 21:56 state_dict.pt
                                         ^^^
                                         almost equal

The difference between torch.save(model.state_dict(), ...) and torch.save(model, ...) is very small. Am I missing something?

My observation is from MNIST example. I ran 10 epochs and here is the result of full model and state dict

-rw-rw-r--  1 ubuntu ubuntu 100M Nov 30 12:37 full_model.pth
-rw-rw-r--  1 ubuntu ubuntu 534K Nov 30 12:37 state_dict.pth

@harupy
Copy link
Member

harupy commented Nov 30, 2020

@shrinath-suresh This is probably because mlflow.pytorch.autolog saves trainer.model (which is a pl.LightningModule object). pl.LightningModule has many attributes that torch.nn.Module doesn't have. These attributes increase the saved model size.

@shrinath-suresh
Copy link
Contributor Author

shrinath-suresh commented Dec 1, 2020

@shrinath-suresh This is probably because mlflow.pytorch.autolog saves trainer.model (which is a pl.LightningModule object). pl.LightningModule has many attributes that torch.nn.Module doesn't have. These attributes increase the saved model size.

You are right. Same mnist example with pytorch shows same size for both state dict and entire model. We can take this discussion in a separate thread, as this PR has no dependency with mlflow.pytorch.autolog.

@harupy Do you have any more comments on the code ?

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
mlflow/pytorch/__init__.py Outdated Show resolved Hide resolved
…s to load the state dict

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
mlflow/pytorch/__init__.py Outdated Show resolved Hide resolved
with open(pickle_module_path, "w") as f:
f.write(pickle_module.__name__)

model_path = os.path.join(model_data_path, _SERIALIZED_TORCH_MODEL_FILE_NAME)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A uesr might log a state_dict that represents a checkpoint for inference and/or resuming training (this use case). In this case _SERIALIZED_TORCH_MODEL_FILE_NAME (= "model.pth") doesn't seem to be the right name because it's not a model.

Copy link
Member

@harupy harupy Dec 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe state_dict.pth is better?

Pro: easier to tell it's a state dict.
Con: harder to tell what the state dict represents.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed it to state_dict.pth

mlflow/pytorch/__init__.py Outdated Show resolved Hide resolved
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
@harupy
Copy link
Member

harupy commented Jan 10, 2021

@shrinath-suresh I have pushed some commits to clean up the code :)

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
@shrinath-suresh
Copy link
Contributor Author

@harupy Thank you very much. The changes LGTM. Is there any other comment you have on this PR ? if not can we merge the PR?

Copy link
Member

@harupy harupy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shrinath-suresh LGTM! Thanks for all the hard work 👍

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
@harupy harupy merged commit fcf8b90 into mlflow:master Jan 12, 2021
harupy added a commit to chauhang/mlflow that referenced this pull request Apr 8, 2021
* Adding save_state_dict and load_state_dict method to mlflow.pytorch library

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unwanted changes

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Resetting empty lines

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding Unit tests for save_state_dict and load_state_dict model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding log_state_dict method and refactored load_model to reuse most of the code in load_state_dict method

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unused argument

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Applying black

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* save_state_dict, log_state_dict and load_state_dict with pytorch flavor

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing MLModel file for state dict and adding appropriate conditions to load the state dict

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Updating doc strings

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Setting experimental annotation and saving state dict as state_dict.pth

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Fixing doc strings

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing state_dict key from save_model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Applying black

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing doc string

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* swapping arguments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Using get_artifact_uri to derive model path

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing pickle_module from save and log state dict

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* rephrasing doc strings

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Renaming tests

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Comparing state dicts in test

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Disabling reimport error

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing blank line between params in doc string

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Replacing _get_model_artifact_path with _download_artifact_from_uri

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* creating get_sequential_model utility and renamving model_class to model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing pd.DataFrame type conversion

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding compare state dicts utility

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing Ordered Dictionary from doc string

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Fixing Docstring

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unused variable

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unused import

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unrelated change

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing data folder

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* revert changes on load_model

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove redundant folder generation

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Set exist_ok to True

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Assert state_dict is dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* wording fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* kwargs

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove redundant model.eval

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Prevent false positive

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* test for nested_state_dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* blank line

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* move tests

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* put state dict functions in one place

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove unused variable

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* comment on test_save_state_dict_can_save_nested_state_dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* ensure model and optim can load state dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* enhance comment

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* comment

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* dot

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove useless comma

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* use pos args

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* rename

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* nit

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* article

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* example

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Add checkpoint example

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove ...

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* warning

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

Co-authored-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
harupy added a commit to wamartin-aml/mlflow that referenced this pull request Jun 7, 2021
* Adding save_state_dict and load_state_dict method to mlflow.pytorch library

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unwanted changes

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Resetting empty lines

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding Unit tests for save_state_dict and load_state_dict model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding log_state_dict method and refactored load_model to reuse most of the code in load_state_dict method

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unused argument

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Applying black

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* save_state_dict, log_state_dict and load_state_dict with pytorch flavor

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing MLModel file for state dict and adding appropriate conditions to load the state dict

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Updating doc strings

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Setting experimental annotation and saving state dict as state_dict.pth

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Fixing doc strings

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing state_dict key from save_model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Applying black

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing doc string

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* swapping arguments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Using get_artifact_uri to derive model path

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing pickle_module from save and log state dict

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* rephrasing doc strings

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Renaming tests

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Comparing state dicts in test

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Disabling reimport error

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing blank line between params in doc string

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Replacing _get_model_artifact_path with _download_artifact_from_uri

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* creating get_sequential_model utility and renamving model_class to model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing pd.DataFrame type conversion

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding compare state dicts utility

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing Ordered Dictionary from doc string

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Fixing Docstring

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unused variable

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unused import

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unrelated change

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing data folder

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* revert changes on load_model

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove redundant folder generation

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Set exist_ok to True

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Assert state_dict is dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* wording fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* kwargs

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove redundant model.eval

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Prevent false positive

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* test for nested_state_dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* blank line

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* move tests

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* put state dict functions in one place

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove unused variable

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* comment on test_save_state_dict_can_save_nested_state_dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* ensure model and optim can load state dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* enhance comment

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* comment

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* dot

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove useless comma

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* use pos args

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* rename

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* nit

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* article

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* example

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Add checkpoint example

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove ...

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* warning

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

Co-authored-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
area/models MLmodel format, model serialization/deserialization, flavors rn/feature Mention under Features in Changelogs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants