Skip to content

Commit

Permalink
fix: latest model calling (#145)
Browse files Browse the repository at this point in the history
* fix: latest model calling

Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>

* fix: add custome exception

Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>

* add: add test cases

Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>

* fix: too long comments

Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>

* fix: log

Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>

* fix: exception to error

Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>

* fix: poetry version patch

Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>

---------

Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>
  • Loading branch information
s0nicboOm committed Mar 17, 2023
1 parent 6ee3446 commit b6f63ef
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 37 deletions.
88 changes: 55 additions & 33 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import logging
from enum import Enum
from typing import Optional, Sequence
from typing import Optional, Sequence, Dict, Any, Tuple

import mlflow.pyfunc
import mlflow.pytorch
Expand All @@ -22,6 +22,7 @@
from mlflow.tracking import MlflowClient

from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.tools.exceptions import ModelVersionError
from numalogic.tools.types import Artifact

_LOGGER = logging.getLogger()
Expand All @@ -39,7 +40,10 @@ class ModelStage(str, Enum):

class MLflowRegistry(ArtifactManager):
"""
Model saving and loading using MLFlow Registry.
Model saving and loading using MLFlow Registry. The parameter model_stage
determines what environment we are using. The old models are moved to
'Archived' state and the latest model comes to 'Staging' or 'Production'
depending on model_stage parameter.
More details here: https://mlflow.org/docs/latest/model-registry.html
Expand All @@ -49,6 +53,9 @@ class MLflowRegistry(ArtifactManager):
supported values include:
{"pytorch", "sklearn", "tensorflow", "pyfunc"}
models_to_retain: number of models to retain in the DB (default = 5)
model_stage: Staging environment from where to load the latest model from (mlflow )
supported values include:
{"Staging", "Production", "Archived"}(default = "Production")
Examples
--------
Expand All @@ -63,14 +70,15 @@ class MLflowRegistry(ArtifactManager):
>>> artifact_data = registry.load(skeys=["model"], dkeys=["AE"])
"""

__slots__ = ("client", "handler", "models_to_retain")
__slots__ = ("client", "handler", "models_to_retain", "model_stage")
_TRACKING_URI = None

def __new__(
cls,
tracking_uri: Optional[str],
artifact_type: str = "pytorch",
models_to_retain: int = 5,
model_stage: ModelStage = ModelStage.PRODUCTION,
*args,
**kwargs,
):
Expand All @@ -80,13 +88,18 @@ def __new__(
return instance

def __init__(
self, tracking_uri: str, artifact_type: str = "pytorch", models_to_retain: int = 5
self,
tracking_uri: str,
artifact_type: str = "pytorch",
models_to_retain: int = 5,
model_stage: str = ModelStage.PRODUCTION,
):
super().__init__(tracking_uri)
mlflow.set_tracking_uri(tracking_uri)
self.client = MlflowClient()
self.handler = self.mlflow_handler(artifact_type)
self.models_to_retain = models_to_retain
self.model_stage = model_stage

@staticmethod
def construct_key(skeys: Sequence[str], dkeys: Sequence[str]) -> str:
Expand Down Expand Up @@ -127,25 +140,19 @@ def load(
) -> Optional[ArtifactData]:
model_key = self.construct_key(skeys, dkeys)
try:
if latest:
model = self.handler.load_model(
model_uri=f"models:/{model_key}/{ModelStage.PRODUCTION}"
)
version_info = self.client.get_latest_versions(
model_key, stages=[ModelStage.PRODUCTION]
)[-1]
elif version is not None:
model = self.handler.load_model(model_uri=f"models:/{model_key}/{version}")
version_info = self.client.get_model_version(model_key, version)
if (latest and version) or (not latest and not version):
raise ValueError("Either One of 'latest' or 'version' needed in load method call")

elif latest:
version_info = self.client.get_latest_versions(model_key, stages=[self.model_stage])
if not version_info:
raise ModelVersionError("Model version missing for key = %s" % model_key)
version_info = version_info[-1]
else:
raise ValueError("One of 'latest' or 'version' needed in load method call")
_LOGGER.info("Successfully loaded model %s from Mlflow", model_key)

run_info = mlflow.get_run(version_info.run_id)
metadata = run_info.data.params or None
_LOGGER.info("Successfully loaded model metadata from Mlflow!")

version_info = self.client.get_model_version(model_key, version)
model, metadata = self.__load_artifacts(skeys, dkeys, version_info)
return ArtifactData(artifact=model, metadata=metadata, extras=dict(version_info))

except RestException as mlflow_err:
if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST:
_LOGGER.info("Model not found with key: %s", model_key)
Expand All @@ -154,8 +161,15 @@ def load(
"Mlflow error when loading a model with key: %s: %r", model_key, mlflow_err
)
return None
except ModelVersionError as model_missing_err:
_LOGGER.error(
"No Model found found in %s ERROR: %r",
self.model_stage,
model_missing_err,
)
return None
except Exception as ex:
_LOGGER.exception("Error when loading a model with key: %s: %r", model_key, ex)
_LOGGER.exception("Unexpected error: %s", ex)
return None

def save(
Expand Down Expand Up @@ -226,25 +240,17 @@ def transition_stage(
"""
model_name = self.construct_key(skeys, dkeys)
try:
current_production = self.client.get_latest_versions(
name=model_name, stages=["Production"]
current_staging = self.client.get_latest_versions(
name=model_name, stages=[self.model_stage]
)
current_staging = self.client.get_latest_versions(name=model_name, stages=["Staging"])
latest = self.client.get_latest_versions(name=model_name, stages=["None"])

latest_model_data = self.client.transition_model_version_stage(
name=model_name,
version=str(latest[-1].version),
stage=ModelStage.PRODUCTION,
stage=self.model_stage,
)

if current_production:
self.client.transition_model_version_stage(
name=model_name,
version=str(current_production[-1].version),
stage=ModelStage.STAGE,
)

if current_staging:
self.client.transition_model_version_stage(
name=model_name,
Expand All @@ -271,3 +277,19 @@ def __delete_stale_models(self, skeys: Sequence[str], dkeys: Sequence[str]):
for stale_model in models_to_delete:
self.delete(skeys=skeys, dkeys=dkeys, version=stale_model.version)
_LOGGER.debug("Deleted stale model version : %s", stale_model.version)

def __load_artifacts(
self, skeys: Sequence[str], dkeys: Sequence[str], version_info: ModelVersion
) -> Tuple[Artifact, Dict[str, Any]]:
model_key = self.construct_key(skeys, dkeys)
model = self.handler.load_model(model_uri=f"models:/{model_key}/{version_info.version}")
_LOGGER.info("Successfully loaded model %s from Mlflow", model_key)

run_info = mlflow.get_run(version_info.run_id)
metadata = run_info.data.params or {}
_LOGGER.info(
"Successfully loaded model = %s with version %s Mlflow!",
model_key,
version_info.version,
)
return model, metadata
4 changes: 4 additions & 0 deletions numalogic/tools/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ class InvalidDataShapeError(Exception):

class UnknownConfigArgsError(Exception):
pass


class ModelVersionError(Exception):
pass
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.3.6"
version = "0.3.7"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
37 changes: 34 additions & 3 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
from mlflow import ActiveRun
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST, ErrorCode, RESOURCE_LIMIT_EXCEEDED
from mlflow.store.entities import PagedList
from sklearn.ensemble import RandomForestRegressor

from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry import MLflowRegistry
from numalogic.registry.mlflow_registry import ModelStage
from numalogic.tools.exceptions import ModelVersionError
from tests.registry._mlflow_utils import (
model_sklearn,
create_model,
Expand Down Expand Up @@ -120,7 +123,7 @@ def test_load_model_when_pytorch_model_exist2(self):
artifact=model,
)
data = ml.load(skeys=skeys, dkeys=dkeys)
self.assertIsNone(data.metadata)
self.assertEqual(data.metadata, {})
self.assertIsInstance(data.artifact, VanillaAE)

@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
Expand All @@ -143,7 +146,7 @@ def test_load_model_when_sklearn_model_exist(self):
)
data = ml.load(skeys=skeys, dkeys=dkeys)
self.assertIsInstance(data.artifact, RandomForestRegressor)
self.assertIsNone(data.metadata)
self.assertEqual(data.metadata, {})

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_empty_rundata())))
Expand All @@ -165,7 +168,35 @@ def test_load_model_with_version(self):
)
data = ml.load(skeys=skeys, dkeys=dkeys, version="5", latest=False)
self.assertIsInstance(data.artifact, VanillaAE)
self.assertIsNone(data.metadata)
self.assertEqual(data.metadata, {})

@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch(
"mlflow.tracking.MlflowClient.get_latest_versions",
Mock(return_value=PagedList(items=[], token=None)),
)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_staging_model_load_error(self):
ml = MLflowRegistry(TRACKING_URI, model_stage=ModelStage.STAGE)
skeys = self.skeys
dkeys = self.dkeys
ml.load(skeys=skeys, dkeys=dkeys)
self.assertRaises(ModelVersionError)

@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version())
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_both_version_latest_model_with_version(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
with self.assertLogs(level="ERROR") as log:
ml.load(skeys=skeys, dkeys=dkeys, latest=False)
self.assertTrue(log.output)

@patch("mlflow.pyfunc.load_model", Mock(side_effect=RuntimeError))
def test_load_model_when_no_model_01(self):
Expand Down

0 comments on commit b6f63ef

Please sign in to comment.