diff --git a/numalogic/registry/artifact.py b/numalogic/registry/artifact.py
index dfd4ee6e..9e255b93 100644
--- a/numalogic/registry/artifact.py
+++ b/numalogic/registry/artifact.py
@@ -1,5 +1,5 @@
 from abc import ABCMeta, abstractmethod
-from typing import Sequence, Any
+from typing import Sequence, Any, Union, Dict
 
 from numalogic.tools.types import Artifact
 
@@ -34,7 +34,7 @@ def save(
         skeys: Sequence[str],
         dkeys: Sequence[str],
         primary_artifact: Artifact,
-        secondary_artifact: Artifact = None,
+        secondary_artifacts: Union[Sequence[Artifact], Dict[str, Artifact], None] = None,
         **metadata
     ) -> Any:
         r"""
@@ -43,7 +43,7 @@ def save(
             skeys: static key fields as list/tuple of strings
             dkeys: dynamic key fields as list/tuple of strings
             primary_artifact: primary artifact to be saved
-            secondary_artifact: secondary artifact to be saved
+            secondary_artifacts: secondary artifacts to be saved (sequence or dict)
             metadata: additional metadata surrounding the artifact that needs to be saved
         """
         pass
diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py
index e741b4bc..c3e0b0cc 100644
--- a/numalogic/registry/mlflow_registry.py
+++ b/numalogic/registry/mlflow_registry.py
@@ -37,6 +37,7 @@ class MLflowRegistrar(ArtifactManager):
         artifact_type: the type of primary artifact to use
                               supported values include:
                               {"pytorch", "sklearn", "tensorflow", "pyfunc"}
+        models_to_retain: number of models to retain in the DB (default = 5)
 
     Examples
     --------
@@ -54,11 +55,14 @@ class MLflowRegistrar(ArtifactManager):
     >>> data = ml.load(skeys=["model"],dkeys=["AE"])
     """
 
-    def __init__(self, tracking_uri: str, artifact_type: str = "pytorch"):
+    def __init__(
+        self, tracking_uri: str, artifact_type: str = "pytorch", models_to_retain: int = 5
+    ):
         super().__init__(tracking_uri)
         mlflow.set_tracking_uri(tracking_uri)
         self.client = MlflowClient()
         self.handler = self.mlflow_handler(artifact_type)
+        self.models_to_retain = models_to_retain
 
     @staticmethod
     def __as_dict(
@@ -194,7 +198,7 @@ def save(
                 data = codecs.encode(pickle.dumps(metadata), "base64").decode()
                 mlflow.log_param(key="metadata", value=data)
             mlflow.log_param(key="model_key", value=model_key)
-            model_version = self.transition_stage(model_name=model_key)
+            model_version = self.transition_stage(skeys=skeys, dkeys=dkeys)
             _LOGGER.info("Successfully inserted model %s to Mlflow", model_key)
             return model_version
         except Exception as ex:
@@ -219,17 +223,20 @@ def delete(self, skeys: Sequence[str], dkeys: Sequence[str], version: str) -> No
         except Exception as ex:
             _LOGGER.exception("Error when deleting a model with key: %s: %r", model_key, ex)
 
-    def transition_stage(self, model_name: str) -> Optional[ModelVersion]:
+    def transition_stage(
+        self, skeys: Sequence[str], dkeys: Sequence[str]
+    ) -> Optional[ModelVersion]:
         """
         Changes stage information for the given model. Sets new model to "Production". The old
         production model is set to "Staging" and the rest model versions are set to "Archived".
 
         Args:
-            model_name: model name for which we are updating the stage information.
-
+            skeys: static key fields as list/tuple of strings
+            dkeys: dynamic key fields as list/tuple of strings
         Returns:
             mlflow ModelVersion instance
         """
+        model_name = self.construct_key(skeys, dkeys)
         try:
             version = int(self.get_version(model_name=model_name))
             latest_model_data = self.client.transition_model_version_stage(
@@ -249,6 +256,18 @@ def transition_stage(self, model_name: str) -> Optional[ModelVersion]:
                     version=str(version - 2),
                     stage=ModelStage.ARCHIVE,
                 )
+
+            # Prune stale versions, keeping only the newest "models_to_retain".
+            # Sort by numeric version first: the search result order is not documented
+            # to be oldest-first, and versions are strings ("10" < "9" lexically).
+            # A non-positive "models_to_retain" disables pruning instead of slicing oddly.
+            if self.models_to_retain > 0:
+                list_model_versions = sorted(
+                    self.client.search_model_versions(f"name='{model_name}'"),
+                    key=lambda model: int(model.version),
+                )
+                for stale_model in list_model_versions[: -self.models_to_retain]:
+                    self.delete(skeys=skeys, dkeys=dkeys, version=stale_model.version)
             _LOGGER.info("Successfully transitioned model to Production stage")
             return latest_model_data
         except Exception as ex:
@@ -266,7 +285,7 @@ def get_version(self, model_name: str) -> Optional[ModelVersion]:
             version from mlflow ModelVersion instance
         """
         try:
-            return self.client.get_latest_versions(model_name, stages=[])[-1].version
+            return self.client.get_latest_versions(name=model_name, stages=[])[-1].version
         except RestException as ex:
             _LOGGER.error("Error when getting model version: %r", ex)
             return None
diff --git a/numalogic/tests/registry/_mlflow_utils.py b/numalogic/tests/registry/_mlflow_utils.py
index eb2b3644..7fe24c97 100644
--- a/numalogic/tests/registry/_mlflow_utils.py
+++ b/numalogic/tests/registry/_mlflow_utils.py
@@ -6,6 +6,7 @@
 from mlflow.entities import RunData, RunInfo, Run
 from mlflow.entities.model_registry import ModelVersion
 from mlflow.models.model import ModelInfo
+from mlflow.store.entities import PagedList
 from sklearn.ensemble import RandomForestRegressor
 from sklearn.preprocessing import StandardScaler
 from torch import tensor
@@ -149,7 +150,7 @@ def mock_transition_stage(*_, **__):
         status_message="",
         tags={},
         user_id="",
-        version="5",
+        version="2",
     )
 
 
@@ -173,6 +174,77 @@ def mock_get_model_version(*_, **__):
     ]
 
 
+def mock_list_of_model_version(*_, **__):
+    model_list = [
+        ModelVersion(
+            creation_timestamp=1653402941169,
+            current_stage="Production",
+            description="",
+            last_updated_timestamp=1653402941191,
+            name="testtest:error",
+            run_id="6e85c26e6e8b49fdb493807d5a527a2c",
+            run_link="",
+            source="mlflow-artifacts:/0/6e85c26e6e8b49fdb493807d5a527a2c/artifacts/model",
+            status="READY",
+            status_message="",
+            tags={},
+            user_id="",
+            version="8",
+        ),
+        ModelVersion(
+            creation_timestamp=1653402941169,
+            current_stage="Production",
+            description="",
+            last_updated_timestamp=1653402941191,
+            name="testtest:error",
+            run_id="6e85c26e6e8b49fdb493807d5a527a2c",
+            run_link="",
+            source="mlflow-artifacts:/0/6e85c26e6e8b49fdb493807d5a527a2c/artifacts/model",
+            status="READY",
+            status_message="",
+            tags={},
+            user_id="",
+            version="9",
+        ),
+        ModelVersion(
+            creation_timestamp=1653402941169,
+            current_stage="Production",
+            description="",
+            last_updated_timestamp=1653402941191,
+            name="testtest:error",
+            run_id="6e85c26e6e8b49fdb493807d5a527a2c",
+            run_link="",
+            source="mlflow-artifacts:/0/6e85c26e6e8b49fdb493807d5a527a2c/artifacts/model",
+            status="READY",
+            status_message="",
+            tags={},
+            user_id="",
+            version="10",
+        ),
+        ModelVersion(
+            creation_timestamp=1653402941169,
+            current_stage="Production",
+            description="",
+            last_updated_timestamp=1653402941191,
+            name="testtest:error",
+            run_id="6e85c26e6e8b49fdb493807d5a527a2c",
+            run_link="",
+            source="mlflow-artifacts:/0/6e85c26e6e8b49fdb493807d5a527a2c/artifacts/model",
+            status="READY",
+            status_message="",
+            tags={},
+            user_id="",
+            version="11",
+        ),
+    ]
+
+    return PagedList(items=model_list, token=None)
+
+
+def mock_list_of_model_version2(*_, **__):
+    return PagedList(items=mock_get_model_version(), token=None)
+
+
 def return_scaler():
     scaler = StandardScaler()
     data = [[0, 0], [0, 0], [1, 1], [1, 1]]
diff --git a/numalogic/tests/registry/test_mlflow_registry.py b/numalogic/tests/registry/test_mlflow_registry.py
index eb26a661..2c797ac7 100644
--- a/numalogic/tests/registry/test_mlflow_registry.py
+++ b/numalogic/tests/registry/test_mlflow_registry.py
@@ -22,6 +22,8 @@
     return_pytorch_rundata_dict,
     return_empty_rundata,
     return_pytorch_rundata_list,
+    mock_list_of_model_version,
+    mock_list_of_model_version2,
 )
 
 TRACKING_URI = "http://0.0.0.0:5009"
@@ -50,9 +52,10 @@ def test_construct_key(self):
     @patch("mlflow.log_param", mock_log_state_dict)
     @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
     @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version)
     def test_insert_model(self):
         model = self.model
-        ml = MLflowRegistrar(TRACKING_URI, artifact_type="pytorch")
+        ml = MLflowRegistrar(TRACKING_URI, artifact_type="pytorch", models_to_retain=2)
         skeys = ["model_", "nnet"]
         dkeys = ["error1"]
 
@@ -69,6 +72,7 @@ def test_insert_model(self):
     @patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
     @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
     @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
     def test_insert_model_sklearn(self):
         model = self.model_sklearn
         ml = MLflowRegistrar(TRACKING_URI, artifact_type="sklearn")
@@ -104,7 +108,6 @@ def test_select_model_when_pytorch_model_exist1(self):
         )
         data = ml.load(skeys=skeys, dkeys=dkeys)
         self.assertIsInstance(data["primary_artifact"], VanillaAE)
-        print(data["secondary_artifacts"])
         self.assertIsInstance(data["secondary_artifacts"]["preproc"], Pipeline)
         self.assertIsInstance(data["secondary_artifacts"]["postproc"], Pipeline)
 
@@ -130,12 +133,12 @@ def test_select_model_when_pytorch_model_exist2(self):
         )
         data = ml.load(skeys=skeys, dkeys=dkeys)
         self.assertIsInstance(data["primary_artifact"], VanillaAE)
-        print(data["secondary_artifacts"])
         self.assertIsInstance(data["secondary_artifacts"], list)
 
     @patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
     @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
     @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
     @patch("mlflow.sklearn.load_model", Mock(return_value=RandomForestRegressor()))
     @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
     def test_select_model_when_sklearn_model_exist(self):
@@ -172,10 +175,6 @@ def test_select_model_with_version(self):
         self.assertIsInstance(data["primary_artifact"], VanillaAE)
         self.assertEqual(data["metadata"], None)
 
-    @patch("mlflow.pyfunc.log_model", mock_log_model_pytorch)
-    @patch("mlflow.log_param", mock_log_state_dict)
-    @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
-    @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
     @patch("mlflow.pyfunc.load_model", Mock(side_effect=RuntimeError))
     def test_select_model_when_no_model_01(self):
         fake_skeys = ["Fakemodel_"]
@@ -185,10 +184,6 @@ def test_select_model_when_no_model_01(self):
             ml.load(skeys=fake_skeys, dkeys=fake_dkeys)
             self.assertTrue(log.output)
 
-    @patch("mlflow.tensorflow.log_model", mock_log_model_pytorch)
-    @patch("mlflow.log_param", mock_log_state_dict)
-    @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
-    @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
     @patch("mlflow.tensorflow.load_model", Mock(side_effect=RuntimeError))
     def test_select_model_when_no_model_02(self):
         fake_skeys = ["Fakemodel_"]
@@ -198,6 +193,19 @@ def test_select_model_when_no_model_02(self):
             ml.load(skeys=fake_skeys, dkeys=fake_dkeys)
             self.assertTrue(log.output)
 
+    @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
+    @patch(
+        "mlflow.tracking.MlflowClient.transition_model_version_stage",
+        Mock(side_effect=RuntimeError),
+    )
+    def test_transition_stage_fail(self):
+        fake_skeys = ["Fakemodel_"]
+        fake_dkeys = ["error"]
+        ml = MLflowRegistrar(TRACKING_URI, artifact_type="tensorflow")
+        with self.assertLogs(level="ERROR") as log:
+            ml.transition_stage(fake_skeys, fake_dkeys)
+            self.assertTrue(log.output)
+
     def test_no_implementation(self):
         with self.assertRaises(NotImplementedError):
             MLflowRegistrar(TRACKING_URI, artifact_type="some_random")
@@ -206,6 +214,7 @@ def test_no_implementation(self):
     @patch("mlflow.log_param", mock_log_state_dict)
     @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
     @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
+    @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
     @patch("mlflow.tracking.MlflowClient.delete_model_version", None)
     @patch("mlflow.pytorch.load_model", Mock(side_effect=RuntimeError))
     def test_delete_model_when_model_exist(self):
diff --git a/pyproject.toml b/pyproject.toml
index f368dbc5..459ae290 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "numalogic"
-version = "0.2.2"
+version = "0.2.3"
 description = "Collection of operational Machine Learning models and tools."
 authors = ["Numalogic Developers"]
 packages = [{ include = "numalogic" }]