From b6f63efc11fce355367834e707fb255549c06d39 Mon Sep 17 00:00:00 2001 From: Kushal Batra <34571348+s0nicboOm@users.noreply.github.com> Date: Fri, 17 Mar 2023 14:47:13 -0700 Subject: [PATCH] fix: latest model calling (#145) * fix: latest model calling Signed-off-by: s0nicboOm * fix: add custome exception Signed-off-by: s0nicboOm * add: add test cases Signed-off-by: s0nicboOm * fix: too long comments Signed-off-by: s0nicboOm * fix: log Signed-off-by: s0nicboOm * fix: exception to error Signed-off-by: s0nicboOm * fix: poetry version patch Signed-off-by: s0nicboOm --------- Signed-off-by: s0nicboOm --- numalogic/registry/mlflow_registry.py | 88 ++++++++++++++++---------- numalogic/tools/exceptions.py | 4 ++ pyproject.toml | 2 +- tests/registry/test_mlflow_registry.py | 37 ++++++++++- 4 files changed, 94 insertions(+), 37 deletions(-) diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index c431e029..dd531a44 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -12,7 +12,7 @@ import logging from enum import Enum -from typing import Optional, Sequence +from typing import Optional, Sequence, Dict, Any, Tuple import mlflow.pyfunc import mlflow.pytorch @@ -22,6 +22,7 @@ from mlflow.tracking import MlflowClient from numalogic.registry import ArtifactManager, ArtifactData +from numalogic.tools.exceptions import ModelVersionError from numalogic.tools.types import Artifact _LOGGER = logging.getLogger() @@ -39,7 +40,10 @@ class ModelStage(str, Enum): class MLflowRegistry(ArtifactManager): """ - Model saving and loading using MLFlow Registry. + Model saving and loading using MLFlow Registry. The parameter model_stage + determines what environment we are using. The old models are moved to + 'Archived' state and the latest model comes to 'Staging' or 'Production' + depending on model_stage parameter. More details here: https://mlflow.org/docs/latest/model-registry.html @@ -49,6 +53,9 @@ class MLflowRegistry(ArtifactManager): supported values include: {"pytorch", "sklearn", "tensorflow", "pyfunc"} models_to_retain: number of models to retain in the DB (default = 5) + model_stage: Staging environment from where to load the latest model from (mlflow ) + supported values include: + {"Staging", "Production", "Archived"}(default = "Production") Examples -------- @@ -63,7 +70,7 @@ class MLflowRegistry(ArtifactManager): >>> artifact_data = registry.load(skeys=["model"], dkeys=["AE"]) """ - __slots__ = ("client", "handler", "models_to_retain") + __slots__ = ("client", "handler", "models_to_retain", "model_stage") _TRACKING_URI = None def __new__( @@ -71,6 +78,7 @@ def __new__( tracking_uri: Optional[str], artifact_type: str = "pytorch", models_to_retain: int = 5, + model_stage: ModelStage = ModelStage.PRODUCTION, *args, **kwargs, ): @@ -80,13 +88,18 @@ def __new__( return instance def __init__( - self, tracking_uri: str, artifact_type: str = "pytorch", models_to_retain: int = 5 + self, + tracking_uri: str, + artifact_type: str = "pytorch", + models_to_retain: int = 5, + model_stage: str = ModelStage.PRODUCTION, ): super().__init__(tracking_uri) mlflow.set_tracking_uri(tracking_uri) self.client = MlflowClient() self.handler = self.mlflow_handler(artifact_type) self.models_to_retain = models_to_retain + self.model_stage = model_stage @staticmethod def construct_key(skeys: Sequence[str], dkeys: Sequence[str]) -> str: @@ -127,25 +140,19 @@ def load( ) -> Optional[ArtifactData]: model_key = self.construct_key(skeys, dkeys) try: - if latest: - model = self.handler.load_model( - model_uri=f"models:/{model_key}/{ModelStage.PRODUCTION}" - ) - version_info = self.client.get_latest_versions( - model_key, stages=[ModelStage.PRODUCTION] - )[-1] - elif version is not None: - model = self.handler.load_model(model_uri=f"models:/{model_key}/{version}") - version_info = self.client.get_model_version(model_key, version) + if (latest and version) or (not latest and not version): + raise ValueError("Either One of 'latest' or 'version' needed in load method call") + + elif latest: + version_info = self.client.get_latest_versions(model_key, stages=[self.model_stage]) + if not version_info: + raise ModelVersionError("Model version missing for key = %s" % model_key) + version_info = version_info[-1] else: - raise ValueError("One of 'latest' or 'version' needed in load method call") - _LOGGER.info("Successfully loaded model %s from Mlflow", model_key) - - run_info = mlflow.get_run(version_info.run_id) - metadata = run_info.data.params or None - _LOGGER.info("Successfully loaded model metadata from Mlflow!") - + version_info = self.client.get_model_version(model_key, version) + model, metadata = self.__load_artifacts(skeys, dkeys, version_info) return ArtifactData(artifact=model, metadata=metadata, extras=dict(version_info)) + except RestException as mlflow_err: if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST: _LOGGER.info("Model not found with key: %s", model_key) @@ -154,8 +161,15 @@ def load( "Mlflow error when loading a model with key: %s: %r", model_key, mlflow_err ) return None + except ModelVersionError as model_missing_err: + _LOGGER.error( + "No Model found found in %s ERROR: %r", + self.model_stage, + model_missing_err, + ) + return None except Exception as ex: - _LOGGER.exception("Error when loading a model with key: %s: %r", model_key, ex) + _LOGGER.exception("Unexpected error: %s", ex) return None def save( @@ -226,25 +240,17 @@ def transition_stage( """ model_name = self.construct_key(skeys, dkeys) try: - current_production = self.client.get_latest_versions( - name=model_name, stages=["Production"] + current_staging = self.client.get_latest_versions( + name=model_name, stages=[self.model_stage] ) - current_staging = self.client.get_latest_versions(name=model_name, stages=["Staging"]) latest = self.client.get_latest_versions(name=model_name, stages=["None"]) latest_model_data = self.client.transition_model_version_stage( name=model_name, version=str(latest[-1].version), - stage=ModelStage.PRODUCTION, + stage=self.model_stage, ) - if current_production: - self.client.transition_model_version_stage( - name=model_name, - version=str(current_production[-1].version), - stage=ModelStage.STAGE, - ) - if current_staging: self.client.transition_model_version_stage( name=model_name, @@ -271,3 +277,19 @@ def __delete_stale_models(self, skeys: Sequence[str], dkeys: Sequence[str]): for stale_model in models_to_delete: self.delete(skeys=skeys, dkeys=dkeys, version=stale_model.version) _LOGGER.debug("Deleted stale model version : %s", stale_model.version) + + def __load_artifacts( + self, skeys: Sequence[str], dkeys: Sequence[str], version_info: ModelVersion + ) -> Tuple[Artifact, Dict[str, Any]]: + model_key = self.construct_key(skeys, dkeys) + model = self.handler.load_model(model_uri=f"models:/{model_key}/{version_info.version}") + _LOGGER.info("Successfully loaded model %s from Mlflow", model_key) + + run_info = mlflow.get_run(version_info.run_id) + metadata = run_info.data.params or {} + _LOGGER.info( + "Successfully loaded model = %s with version %s Mlflow!", + model_key, + version_info.version, + ) + return model, metadata diff --git a/numalogic/tools/exceptions.py b/numalogic/tools/exceptions.py index 10462758..c338f580 100644 --- a/numalogic/tools/exceptions.py +++ b/numalogic/tools/exceptions.py @@ -36,3 +36,7 @@ class InvalidDataShapeError(Exception): class UnknownConfigArgsError(Exception): pass + + +class ModelVersionError(Exception): + pass diff --git a/pyproject.toml b/pyproject.toml index 0a34a23f..ce0ca14e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "numalogic" -version = "0.3.6" +version = "0.3.7" description = "Collection of operational Machine Learning models and tools." authors = ["Numalogic Developers"] packages = [{ include = "numalogic" }] diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 63dbc795..92442f86 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -5,10 +5,13 @@ from mlflow import ActiveRun from mlflow.exceptions import RestException from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST, ErrorCode, RESOURCE_LIMIT_EXCEEDED +from mlflow.store.entities import PagedList from sklearn.ensemble import RandomForestRegressor from numalogic.models.autoencoder.variants import VanillaAE from numalogic.registry import MLflowRegistry +from numalogic.registry.mlflow_registry import ModelStage +from numalogic.tools.exceptions import ModelVersionError from tests.registry._mlflow_utils import ( model_sklearn, create_model, @@ -120,7 +123,7 @@ def test_load_model_when_pytorch_model_exist2(self): artifact=model, ) data = ml.load(skeys=skeys, dkeys=dkeys) - self.assertIsNone(data.metadata) + self.assertEqual(data.metadata, {}) self.assertIsInstance(data.artifact, VanillaAE) @patch("mlflow.sklearn.log_model", mock_log_model_sklearn) @@ -143,7 +146,7 @@ def test_load_model_when_sklearn_model_exist(self): ) data = ml.load(skeys=skeys, dkeys=dkeys) self.assertIsInstance(data.artifact, RandomForestRegressor) - self.assertIsNone(data.metadata) + self.assertEqual(data.metadata, {}) @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_empty_rundata()))) @@ -165,7 +168,35 @@ def test_load_model_with_version(self): ) data = ml.load(skeys=skeys, dkeys=dkeys, version="5", latest=False) self.assertIsInstance(data.artifact, VanillaAE) - self.assertIsNone(data.metadata) + self.assertEqual(data.metadata, {}) + + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch( + "mlflow.tracking.MlflowClient.get_latest_versions", + Mock(return_value=PagedList(items=[], token=None)), + ) + @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) + def test_staging_model_load_error(self): + ml = MLflowRegistry(TRACKING_URI, model_stage=ModelStage.STAGE) + skeys = self.skeys + dkeys = self.dkeys + ml.load(skeys=skeys, dkeys=dkeys) + self.assertRaises(ModelVersionError) + + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version()) + @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) + def test_both_version_latest_model_with_version(self): + ml = MLflowRegistry(TRACKING_URI) + skeys = self.skeys + dkeys = self.dkeys + with self.assertLogs(level="ERROR") as log: + ml.load(skeys=skeys, dkeys=dkeys, latest=False) + self.assertTrue(log.output) @patch("mlflow.pyfunc.load_model", Mock(side_effect=RuntimeError)) def test_load_model_when_no_model_01(self):