Skip to content

Commit

Permalink
feat: adding feature for retaining fixed number of stale model (#13)
Browse files Browse the repository at this point in the history
Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>
  • Loading branch information
s0nicboOm authored Aug 11, 2022
1 parent 47b2519 commit 040584f
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 22 deletions.
6 changes: 3 additions & 3 deletions numalogic/registry/artifact.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import Sequence, Any
from typing import Sequence, Any, Union, Dict

from numalogic.tools.types import Artifact

Expand Down Expand Up @@ -34,7 +34,7 @@ def save(
skeys: Sequence[str],
dkeys: Sequence[str],
primary_artifact: Artifact,
secondary_artifact: Artifact = None,
secondary_artifacts: Union[Sequence[Artifact], Dict[str, Artifact], None] = None,
**metadata
) -> Any:
r"""
Expand All @@ -43,7 +43,7 @@ def save(
skeys: static key fields as list/tuple of strings
dkeys: dynamic key fields as list/tuple of strings
primary_artifact: primary artifact to be saved
secondary_artifact: secondary artifact to be saved
secondary_artifacts: secondary artifact to be saved
metadata: additional metadata surrounding the artifact that needs to be saved
"""
pass
Expand Down
25 changes: 19 additions & 6 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class MLflowRegistrar(ArtifactManager):
artifact_type: the type of primary artifact to use
supported values include:
{"pytorch", "sklearn", "tensorflow", "pyfunc"}
models_to_retain: number of models to retain in the DB (default = 5)
Examples
--------
Expand All @@ -54,11 +55,14 @@ class MLflowRegistrar(ArtifactManager):
>>> data = ml.load(skeys=["model"],dkeys=["AE"])
"""

def __init__(self, tracking_uri: str, artifact_type: str = "pytorch"):
def __init__(
self, tracking_uri: str, artifact_type: str = "pytorch", models_to_retain: int = 5
):
super().__init__(tracking_uri)
mlflow.set_tracking_uri(tracking_uri)
self.client = MlflowClient()
self.handler = self.mlflow_handler(artifact_type)
self.models_to_retain = models_to_retain

@staticmethod
def __as_dict(
Expand Down Expand Up @@ -194,7 +198,7 @@ def save(
data = codecs.encode(pickle.dumps(metadata), "base64").decode()
mlflow.log_param(key="metadata", value=data)
mlflow.log_param(key="model_key", value=model_key)
model_version = self.transition_stage(model_name=model_key)
model_version = self.transition_stage(skeys=skeys, dkeys=dkeys)
_LOGGER.info("Successfully inserted model %s to Mlflow", model_key)
return model_version
except Exception as ex:
Expand All @@ -219,17 +223,20 @@ def delete(self, skeys: Sequence[str], dkeys: Sequence[str], version: str) -> No
except Exception as ex:
_LOGGER.exception("Error when deleting a model with key: %s: %r", model_key, ex)

def transition_stage(self, model_name: str) -> Optional[ModelVersion]:
def transition_stage(
self, skeys: Sequence[str], dkeys: Sequence[str]
) -> Optional[ModelVersion]:
"""
Changes stage information for the given model. Sets new model to "Production". The old
production model is set to "Staging" and the rest model versions are set to "Archived".
Args:
model_name: model name for which we are updating the stage information.
skeys: static key fields as list/tuple of strings
dkeys: dynamic key fields as list/tuple of strings
Returns:
mlflow ModelVersion instance
"""
model_name = self.construct_key(skeys, dkeys)
try:
version = int(self.get_version(model_name=model_name))
latest_model_data = self.client.transition_model_version_stage(
Expand All @@ -249,6 +256,12 @@ def transition_stage(self, model_name: str) -> Optional[ModelVersion]:
version=str(version - 2),
stage=ModelStage.ARCHIVE,
)

# only keep "models_to_retain" number of models.
list_model_versions = list(self.client.search_model_versions(f"name='{model_name}'"))
models_to_delete = list_model_versions[: -self.models_to_retain]
for stale_model in models_to_delete:
self.delete(skeys=skeys, dkeys=dkeys, version=stale_model.version)
_LOGGER.info("Successfully transitioned model to Production stage")
return latest_model_data
except Exception as ex:
Expand All @@ -266,7 +279,7 @@ def get_version(self, model_name: str) -> Optional[ModelVersion]:
version from mlflow ModelVersion instance
"""
try:
return self.client.get_latest_versions(model_name, stages=[])[-1].version
return self.client.get_latest_versions(name=model_name, stages=[])[-1].version
except RestException as ex:
_LOGGER.error("Error when getting model version: %r", ex)
return None
74 changes: 73 additions & 1 deletion numalogic/tests/registry/_mlflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mlflow.entities import RunData, RunInfo, Run
from mlflow.entities.model_registry import ModelVersion
from mlflow.models.model import ModelInfo
from mlflow.store.entities import PagedList
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from torch import tensor
Expand Down Expand Up @@ -149,7 +150,7 @@ def mock_transition_stage(*_, **__):
status_message="",
tags={},
user_id="",
version="5",
version="2",
)


Expand All @@ -173,6 +174,77 @@ def mock_get_model_version(*_, **__):
]


def mock_list_of_model_version(*_, **__):
model_list = [
ModelVersion(
creation_timestamp=1653402941169,
current_stage="Production",
description="",
last_updated_timestamp=1653402941191,
name="testtest:error",
run_id="6e85c26e6e8b49fdb493807d5a527a2c",
run_link="",
source="mlflow-artifacts:/0/6e85c26e6e8b49fdb493807d5a527a2c/artifacts/model",
status="READY",
status_message="",
tags={},
user_id="",
version="8",
),
ModelVersion(
creation_timestamp=1653402941169,
current_stage="Production",
description="",
last_updated_timestamp=1653402941191,
name="testtest:error",
run_id="6e85c26e6e8b49fdb493807d5a527a2c",
run_link="",
source="mlflow-artifacts:/0/6e85c26e6e8b49fdb493807d5a527a2c/artifacts/model",
status="READY",
status_message="",
tags={},
user_id="",
version="9",
),
ModelVersion(
creation_timestamp=1653402941169,
current_stage="Production",
description="",
last_updated_timestamp=1653402941191,
name="testtest:error",
run_id="6e85c26e6e8b49fdb493807d5a527a2c",
run_link="",
source="mlflow-artifacts:/0/6e85c26e6e8b49fdb493807d5a527a2c/artifacts/model",
status="READY",
status_message="",
tags={},
user_id="",
version="10",
),
ModelVersion(
creation_timestamp=1653402941169,
current_stage="Production",
description="",
last_updated_timestamp=1653402941191,
name="testtest:error",
run_id="6e85c26e6e8b49fdb493807d5a527a2c",
run_link="",
source="mlflow-artifacts:/0/6e85c26e6e8b49fdb493807d5a527a2c/artifacts/model",
status="READY",
status_message="",
tags={},
user_id="",
version="11",
),
]

return PagedList(items=model_list, token=None)


def mock_list_of_model_version2(*_, **__):
return PagedList(items=mock_get_model_version(), token=None)


def return_scaler():
scaler = StandardScaler()
data = [[0, 0], [0, 0], [1, 1], [1, 1]]
Expand Down
32 changes: 21 additions & 11 deletions numalogic/tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
return_pytorch_rundata_dict,
return_empty_rundata,
return_pytorch_rundata_list,
mock_list_of_model_version,
mock_list_of_model_version2,
)

TRACKING_URI = "http://0.0.0.0:5009"
Expand Down Expand Up @@ -50,9 +52,10 @@ def test_construct_key(self):
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version)
def test_insert_model(self):
model = self.model
ml = MLflowRegistrar(TRACKING_URI, artifact_type="pytorch")
ml = MLflowRegistrar(TRACKING_URI, artifact_type="pytorch", models_to_retain=2)
skeys = ["model_", "nnet"]
dkeys = ["error1"]

Expand All @@ -62,13 +65,15 @@ def test_insert_model(self):
primary_artifact=model,
secondary_artifacts=[make_pipeline(return_scaler)],
artifact_state_dict=model.state_dict(),
models_to_retain=2,
)
mock_status = "READY"
self.assertEqual(mock_status, status.status)

@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
def test_insert_model_sklearn(self):
model = self.model_sklearn
ml = MLflowRegistrar(TRACKING_URI, artifact_type="sklearn")
Expand Down Expand Up @@ -104,7 +109,6 @@ def test_select_model_when_pytorch_model_exist1(self):
)
data = ml.load(skeys=skeys, dkeys=dkeys)
self.assertIsInstance(data["primary_artifact"], VanillaAE)
print(data["secondary_artifacts"])
self.assertIsInstance(data["secondary_artifacts"]["preproc"], Pipeline)
self.assertIsInstance(data["secondary_artifacts"]["postproc"], Pipeline)

Expand All @@ -130,12 +134,12 @@ def test_select_model_when_pytorch_model_exist2(self):
)
data = ml.load(skeys=skeys, dkeys=dkeys)
self.assertIsInstance(data["primary_artifact"], VanillaAE)
print(data["secondary_artifacts"])
self.assertIsInstance(data["secondary_artifacts"], list)

@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.sklearn.load_model", Mock(return_value=RandomForestRegressor()))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_select_model_when_sklearn_model_exist(self):
Expand Down Expand Up @@ -172,10 +176,6 @@ def test_select_model_with_version(self):
self.assertIsInstance(data["primary_artifact"], VanillaAE)
self.assertEqual(data["metadata"], None)

@patch("mlflow.pyfunc.log_model", mock_log_model_pytorch)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pyfunc.load_model", Mock(side_effect=RuntimeError))
def test_select_model_when_no_model_01(self):
fake_skeys = ["Fakemodel_"]
Expand All @@ -185,10 +185,6 @@ def test_select_model_when_no_model_01(self):
ml.load(skeys=fake_skeys, dkeys=fake_dkeys)
self.assertTrue(log.output)

@patch("mlflow.tensorflow.log_model", mock_log_model_pytorch)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tensorflow.load_model", Mock(side_effect=RuntimeError))
def test_select_model_when_no_model_02(self):
fake_skeys = ["Fakemodel_"]
Expand All @@ -198,6 +194,19 @@ def test_select_model_when_no_model_02(self):
ml.load(skeys=fake_skeys, dkeys=fake_dkeys)
self.assertTrue(log.output)

@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch(
"mlflow.tracking.MlflowClient.transition_model_version_stage",
Mock(side_effect=RuntimeError),
)
def test_transition_stage_fail(self):
fake_skeys = ["Fakemodel_"]
fake_dkeys = ["error"]
ml = MLflowRegistrar(TRACKING_URI, artifact_type="tensorflow")
with self.assertLogs(level="ERROR") as log:
ml.transition_stage(fake_skeys, fake_dkeys)
self.assertTrue(log.output)

def test_no_implementation(self):
with self.assertRaises(NotImplementedError):
MLflowRegistrar(TRACKING_URI, artifact_type="some_random")
Expand All @@ -206,6 +215,7 @@ def test_no_implementation(self):
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.delete_model_version", None)
@patch("mlflow.pytorch.load_model", Mock(side_effect=RuntimeError))
def test_delete_model_when_model_exist(self):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.2.2"
version = "0.2.3"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down

0 comments on commit 040584f

Please sign in to comment.