diff --git a/mlflow/pytorch/__init__.py b/mlflow/pytorch/__init__.py index 9fa81dd41e29e..e7f2afed1fbab 100644 --- a/mlflow/pytorch/__init__.py +++ b/mlflow/pytorch/__init__.py @@ -39,6 +39,7 @@ FLAVOR_NAME = "pytorch" _SERIALIZED_TORCH_MODEL_FILE_NAME = "model.pth" +_TORCH_STATE_DICT_FILE_NAME = "state_dict.pth" _PICKLE_MODULE_INFO_FILE_NAME = "pickle_module_info.txt" _EXTRA_FILES_KEY = "extra_files" _REQUIREMENTS_FILE_KEY = "requirements_file" @@ -725,6 +726,106 @@ def predict(self, data, device="cpu"): return predicted +@experimental +def log_state_dict(state_dict, artifact_path, **kwargs): + """ + Log a state_dict as an MLflow artifact for the current run. + + .. warning:: + This function just logs a state_dict as an artifact and doesn't generate + an :ref:`MLflow Model `. + + :param state_dict: state_dict to be saved. + :param artifact_path: Run-relative artifact path. + :param kwargs: kwargs to pass to ``torch.save``. + + .. code-block:: python + :caption: Example + + # Log a model as a state_dict + with mlflow.start_run(): + state_dict = model.state_dict() + mlflow.pytorch.log_state_dict(state_dict, artifact_path="model") + + # Log a checkpoint as a state_dict + with mlflow.start_run(): + state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + "loss": loss, + } + mlflow.pytorch.log_state_dict(state_dict, artifact_path="checkpoint") + """ + + with TempDir() as tmp: + local_path = tmp.path() + save_state_dict(state_dict=state_dict, path=local_path, **kwargs) + mlflow.log_artifacts(local_path, artifact_path) + + +@experimental +def save_state_dict(state_dict, path, **kwargs): + """ + Save a state_dict to a path on the local file system + + :param state_dict: state_dict to be saved. + :param path: Local path where the state_dict is to be saved. + :param kwargs: kwargs to pass to ``torch.save``. + """ + import torch + + # The object type check here aims to prevent a scenario where a user accidentally passees + # a model instead of a state_dict and `torch.save` (which accepts both model and state_dict) + # successfully completes, leaving the user unaware of the mistake. + if not isinstance(state_dict, dict): + raise TypeError( + "Invalid object type for `state_dict`: {}. Must be an instance of `dict`".format( + type(state_dict) + ) + ) + + os.makedirs(path, exist_ok=True) + state_dict_path = os.path.join(path, _TORCH_STATE_DICT_FILE_NAME) + torch.save(state_dict, state_dict_path, **kwargs) + + +@experimental +def load_state_dict(state_dict_uri, **kwargs): + """ + Load a state_dict from a local file or a run. + + :param state_dict_uri: The location, in URI format, of the state_dict, for example: + + - ``/Users/me/path/to/local/state_dict`` + - ``relative/path/to/local/state_dict`` + - ``s3://my_bucket/path/to/state_dict`` + - ``runs://run-relative/path/to/state_dict`` + + For more information about supported URI schemes, see + `Referencing Artifacts `_. + + :param kwargs: kwargs to pass to ``torch.load``. + :return: A state_dict + + .. code-block:: python + :caption: Example + + with mlflow.start_run(): + artifact_path = "model" + mlflow.pytorch.log_state_dict(model.state_dict(), artifact_path) + state_dict_uri = mlflow.get_artifact_uri(artifact_path) + + state_dict = mlflow.pytorch.load_state_dict(state_dict_uri) + """ + import torch + + local_path = _download_artifact_from_uri(artifact_uri=state_dict_uri) + state_dict_path = os.path.join(local_path, _TORCH_STATE_DICT_FILE_NAME) + return torch.load(state_dict_path, **kwargs) + + @experimental @autologging_integration(FLAVOR_NAME) def autolog( diff --git a/tests/pytorch/test_pytorch_model_export.py b/tests/pytorch/test_pytorch_model_export.py index 77bf6cc51d204..d6b3ecd9ab05d 100644 --- a/tests/pytorch/test_pytorch_model_export.py +++ b/tests/pytorch/test_pytorch_model_export.py @@ -84,9 +84,13 @@ def train_model(model, data): optimizer.step() +def get_sequential_model(): + return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1)) + + @pytest.fixture def sequential_model(data, scripted_model): - model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1),) + model = get_sequential_model() if scripted_model: model = torch.jit.script(model) @@ -1032,3 +1036,86 @@ def test_log_model_invalid_extra_file_type(sequential_model): conda_env=None, extra_files="inexistent_file.txt", ) + + +def state_dict_equal(state_dict1, state_dict2): + for key1 in state_dict1: + if key1 not in state_dict2: + return False + + value1 = state_dict1[key1] + value2 = state_dict2[key1] + + if type(value1) != type(value2): + return False + elif isinstance(value1, dict): + if not state_dict_equal(value1, value2): + return False + elif isinstance(value1, torch.Tensor): + if not torch.equal(value1, value2): + return False + elif value1 != value2: + return False + else: + continue + + return True + + +@pytest.mark.large +@pytest.mark.parametrize("scripted_model", [True, False]) +def test_save_state_dict(sequential_model, model_path, data): + state_dict = sequential_model.state_dict() + mlflow.pytorch.save_state_dict(state_dict, model_path) + + loaded_state_dict = mlflow.pytorch.load_state_dict(model_path) + assert state_dict_equal(loaded_state_dict, state_dict) + model = get_sequential_model() + model.load_state_dict(loaded_state_dict) + np.testing.assert_array_almost_equal( + _predict(model, data), _predict(sequential_model, data), decimal=4, + ) + + +@pytest.mark.large +def test_save_state_dict_can_save_nested_state_dict(model_path): + """ + This test ensures that `save_state_dict` supports a use case described in the page below + where a user bundles multiple objects (e.g., model, optimizer, learning-rate scheduler) + into a single nested state_dict and loads it back later for inference or re-training: + https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html + """ + model = get_sequential_model() + optim = torch.optim.Adam(model.parameters()) + state_dict = {"model": model.state_dict(), "optim": optim.state_dict()} + mlflow.pytorch.save_state_dict(state_dict, model_path) + + loaded_state_dict = mlflow.pytorch.load_state_dict(model_path) + assert state_dict_equal(loaded_state_dict, state_dict) + model.load_state_dict(loaded_state_dict["model"]) + optim.load_state_dict(loaded_state_dict["optim"]) + + +@pytest.mark.large +@pytest.mark.parametrize("not_state_dict", [0, "", get_sequential_model()]) +def test_save_state_dict_throws_for_invalid_object_type(not_state_dict, model_path): + with pytest.raises(TypeError, match="Invalid object type for `state_dict`"): + mlflow.pytorch.save_state_dict(not_state_dict, model_path) + + +@pytest.mark.large +@pytest.mark.parametrize("scripted_model", [True, False]) +def test_log_state_dict(sequential_model, data): + artifact_path = "model" + state_dict = sequential_model.state_dict() + with mlflow.start_run(): + mlflow.pytorch.log_state_dict(state_dict, artifact_path) + state_dict_uri = mlflow.get_artifact_uri(artifact_path) + + loaded_state_dict = mlflow.pytorch.load_state_dict(state_dict_uri) + assert state_dict_equal(loaded_state_dict, state_dict) + model = get_sequential_model() + model.load_state_dict(loaded_state_dict) + np.testing.assert_array_almost_equal( + _predict(model, data), _predict(sequential_model, data), decimal=4, + )