Skip to content

Commit

Permalink
Saving and Loading pytorch model as state dict (mlflow#3705)
Browse files Browse the repository at this point in the history
* Adding save_state_dict and load_state_dict method to mlflow.pytorch library

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unwanted changes

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Resetting empty lines

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding Unit tests for save_state_dict and load_state_dict model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding log_state_dict method and refactored load_model to reuse most of the code in load_state_dict method

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unused argument

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Applying black

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* save_state_dict, log_state_dict and load_state_dict with pytorch flavor

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing MLModel file for state dict and adding appropriate conditions to load the state dict

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Updating doc strings

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Setting experimental annotation and saving state dict as state_dict.pth

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Fixing doc strings

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing state_dict key from save_model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Applying black

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing doc string

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* swapping arguments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Using get_artifact_uri to derive model path

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing pickle_module from save and log state dict

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* rephrasing doc strings

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Renaming tests

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Comparing state dicts in test

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Disabling reimport error

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing blank line between params in doc string

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Replacing _get_model_artifact_path with _download_artifact_from_uri

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* creating get_sequential_model utility and renamving model_class to model

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing pd.DataFrame type conversion

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Adding compare state dicts utility

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing Ordered Dictionary from doc string

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Fixing Docstring

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unused variable

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unused import

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unrelated change

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing data folder

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* revert changes on load_model

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove redundant folder generation

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Set exist_ok to True

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Assert state_dict is dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* wording fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* kwargs

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove redundant model.eval

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Prevent false positive

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* test for nested_state_dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* blank line

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* move tests

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* put state dict functions in one place

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove unused variable

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* comment on test_save_state_dict_can_save_nested_state_dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* ensure model and optim can load state dict

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* enhance comment

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* comment

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* dot

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove useless comma

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* use pos args

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* rename

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* nit

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* article

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* example

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Add checkpoint example

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove ...

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* warning

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

Co-authored-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
  • Loading branch information
shrinath-suresh and harupy committed Apr 8, 2021
1 parent 8874585 commit db361b2
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 1 deletion.
101 changes: 101 additions & 0 deletions mlflow/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
FLAVOR_NAME = "pytorch"

_SERIALIZED_TORCH_MODEL_FILE_NAME = "model.pth"
_TORCH_STATE_DICT_FILE_NAME = "state_dict.pth"
_PICKLE_MODULE_INFO_FILE_NAME = "pickle_module_info.txt"
_EXTRA_FILES_KEY = "extra_files"
_REQUIREMENTS_FILE_KEY = "requirements_file"
Expand Down Expand Up @@ -725,6 +726,106 @@ def predict(self, data, device="cpu"):
return predicted


@experimental
def log_state_dict(state_dict, artifact_path, **kwargs):
"""
Log a state_dict as an MLflow artifact for the current run.
.. warning::
This function just logs a state_dict as an artifact and doesn't generate
an :ref:`MLflow Model <models>`.
:param state_dict: state_dict to be saved.
:param artifact_path: Run-relative artifact path.
:param kwargs: kwargs to pass to ``torch.save``.
.. code-block:: python
:caption: Example
# Log a model as a state_dict
with mlflow.start_run():
state_dict = model.state_dict()
mlflow.pytorch.log_state_dict(state_dict, artifact_path="model")
# Log a checkpoint as a state_dict
with mlflow.start_run():
state_dict = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"epoch": epoch,
"loss": loss,
}
mlflow.pytorch.log_state_dict(state_dict, artifact_path="checkpoint")
"""

with TempDir() as tmp:
local_path = tmp.path()
save_state_dict(state_dict=state_dict, path=local_path, **kwargs)
mlflow.log_artifacts(local_path, artifact_path)


@experimental
def save_state_dict(state_dict, path, **kwargs):
"""
Save a state_dict to a path on the local file system
:param state_dict: state_dict to be saved.
:param path: Local path where the state_dict is to be saved.
:param kwargs: kwargs to pass to ``torch.save``.
"""
import torch

# The object type check here aims to prevent a scenario where a user accidentally passees
# a model instead of a state_dict and `torch.save` (which accepts both model and state_dict)
# successfully completes, leaving the user unaware of the mistake.
if not isinstance(state_dict, dict):
raise TypeError(
"Invalid object type for `state_dict`: {}. Must be an instance of `dict`".format(
type(state_dict)
)
)

os.makedirs(path, exist_ok=True)
state_dict_path = os.path.join(path, _TORCH_STATE_DICT_FILE_NAME)
torch.save(state_dict, state_dict_path, **kwargs)


@experimental
def load_state_dict(state_dict_uri, **kwargs):
"""
Load a state_dict from a local file or a run.
:param state_dict_uri: The location, in URI format, of the state_dict, for example:
- ``/Users/me/path/to/local/state_dict``
- ``relative/path/to/local/state_dict``
- ``s3://my_bucket/path/to/state_dict``
- ``runs:/<mlflow_run_id>/run-relative/path/to/state_dict``
For more information about supported URI schemes, see
`Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
artifact-locations>`_.
:param kwargs: kwargs to pass to ``torch.load``.
:return: A state_dict
.. code-block:: python
:caption: Example
with mlflow.start_run():
artifact_path = "model"
mlflow.pytorch.log_state_dict(model.state_dict(), artifact_path)
state_dict_uri = mlflow.get_artifact_uri(artifact_path)
state_dict = mlflow.pytorch.load_state_dict(state_dict_uri)
"""
import torch

local_path = _download_artifact_from_uri(artifact_uri=state_dict_uri)
state_dict_path = os.path.join(local_path, _TORCH_STATE_DICT_FILE_NAME)
return torch.load(state_dict_path, **kwargs)


@experimental
@autologging_integration(FLAVOR_NAME)
def autolog(
Expand Down
89 changes: 88 additions & 1 deletion tests/pytorch/test_pytorch_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,13 @@ def train_model(model, data):
optimizer.step()


def get_sequential_model():
return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))


@pytest.fixture
def sequential_model(data, scripted_model):
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1),)
model = get_sequential_model()
if scripted_model:
model = torch.jit.script(model)

Expand Down Expand Up @@ -1032,3 +1036,86 @@ def test_log_model_invalid_extra_file_type(sequential_model):
conda_env=None,
extra_files="inexistent_file.txt",
)


def state_dict_equal(state_dict1, state_dict2):
for key1 in state_dict1:
if key1 not in state_dict2:
return False

value1 = state_dict1[key1]
value2 = state_dict2[key1]

if type(value1) != type(value2):
return False
elif isinstance(value1, dict):
if not state_dict_equal(value1, value2):
return False
elif isinstance(value1, torch.Tensor):
if not torch.equal(value1, value2):
return False
elif value1 != value2:
return False
else:
continue

return True


@pytest.mark.large
@pytest.mark.parametrize("scripted_model", [True, False])
def test_save_state_dict(sequential_model, model_path, data):
state_dict = sequential_model.state_dict()
mlflow.pytorch.save_state_dict(state_dict, model_path)

loaded_state_dict = mlflow.pytorch.load_state_dict(model_path)
assert state_dict_equal(loaded_state_dict, state_dict)
model = get_sequential_model()
model.load_state_dict(loaded_state_dict)
np.testing.assert_array_almost_equal(
_predict(model, data), _predict(sequential_model, data), decimal=4,
)


@pytest.mark.large
def test_save_state_dict_can_save_nested_state_dict(model_path):
"""
This test ensures that `save_state_dict` supports a use case described in the page below
where a user bundles multiple objects (e.g., model, optimizer, learning-rate scheduler)
into a single nested state_dict and loads it back later for inference or re-training:
https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
"""
model = get_sequential_model()
optim = torch.optim.Adam(model.parameters())
state_dict = {"model": model.state_dict(), "optim": optim.state_dict()}
mlflow.pytorch.save_state_dict(state_dict, model_path)

loaded_state_dict = mlflow.pytorch.load_state_dict(model_path)
assert state_dict_equal(loaded_state_dict, state_dict)
model.load_state_dict(loaded_state_dict["model"])
optim.load_state_dict(loaded_state_dict["optim"])


@pytest.mark.large
@pytest.mark.parametrize("not_state_dict", [0, "", get_sequential_model()])
def test_save_state_dict_throws_for_invalid_object_type(not_state_dict, model_path):
with pytest.raises(TypeError, match="Invalid object type for `state_dict`"):
mlflow.pytorch.save_state_dict(not_state_dict, model_path)


@pytest.mark.large
@pytest.mark.parametrize("scripted_model", [True, False])
def test_log_state_dict(sequential_model, data):
artifact_path = "model"
state_dict = sequential_model.state_dict()
with mlflow.start_run():
mlflow.pytorch.log_state_dict(state_dict, artifact_path)
state_dict_uri = mlflow.get_artifact_uri(artifact_path)

loaded_state_dict = mlflow.pytorch.load_state_dict(state_dict_uri)
assert state_dict_equal(loaded_state_dict, state_dict)
model = get_sequential_model()
model.load_state_dict(loaded_state_dict)
np.testing.assert_array_almost_equal(
_predict(model, data), _predict(sequential_model, data), decimal=4,
)

0 comments on commit db361b2

Please sign in to comment.