Skip to content

Commit

Permalink
Fix the bug of stages in models URI being case-sensitive (#5312)
Browse files Browse the repository at this point in the history
Fix the bug of stages in models URI being case-sensitive
  • Loading branch information
lichenran1234 authored and dbczumar committed Jan 26, 2022
1 parent 86ad040 commit 271750b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
9 changes: 3 additions & 6 deletions mlflow/store/artifact/utils/models.py
Expand Up @@ -3,7 +3,6 @@
import mlflow.tracking
from mlflow.exceptions import MlflowException
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri, is_databricks_uri
from mlflow.entities.model_registry.model_version_stages import ALL_STAGES

_MODELS_URI_SUFFIX_LATEST = "latest"

Expand Down Expand Up @@ -60,13 +59,11 @@ def _parse_model_uri(uri):
if parts[1].isdigit():
# The suffix is a specific version, e.g. "models:/AdsModel1/123"
return parts[0], int(parts[1]), None
elif parts[1] == _MODELS_URI_SUFFIX_LATEST:
# The suffix is exactly the 'latest' string, e.g. "models:/AdsModel1/latest"
elif parts[1].lower() == _MODELS_URI_SUFFIX_LATEST.lower():
# The suffix is the 'latest' string (case insensitive), e.g. "models:/AdsModel1/latest"
return parts[0], None, None
elif parts[1] not in ALL_STAGES:
raise MlflowException(_improper_model_uri_msg(uri))
else:
# The suffix is a specific stage, e.g. "models:/AdsModel1/Production"
# The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production"
return parts[0], None, parts[1]


Expand Down
11 changes: 9 additions & 2 deletions tests/store/artifact/utils/test_model_utils.py
Expand Up @@ -27,6 +27,8 @@ def test_parse_models_uri_with_version(uri, expected_name, expected_version):
"uri, expected_name, expected_stage",
[
("models:/AdsModel1/Production", "AdsModel1", "Production"),
("models:/AdsModel1/production", "AdsModel1", "production"), # case insensitive
("models:/AdsModel1/pROduction", "AdsModel1", "pROduction"), # case insensitive
("models:/Ads Model 1/None", "Ads Model 1", "None"),
("models://scope:key@databricks/Ads Model 1/None", "Ads Model 1", "None"),
],
Expand All @@ -42,6 +44,8 @@ def test_parse_models_uri_with_stage(uri, expected_name, expected_stage):
"uri, expected_name",
[
("models:/AdsModel1/latest", "AdsModel1"),
("models:/AdsModel1/Latest", "AdsModel1"), # case insensitive
("models:/AdsModel1/LATEST", "AdsModel1"), # case insensitive
("models:/Ads Model 1/latest", "Ads Model 1"),
("models://scope:key@databricks/Ads Model 1/latest", "Ads Model 1"),
],
Expand All @@ -60,8 +64,6 @@ def test_parse_models_uri_with_latest(uri, expected_name):
"notmodels:/NameOfModel/StageName", # wrong scheme with stage
"models:/", # no model name
"models:/Name/Stage/0", # too many specifiers
"models:/Name/production", # should be 'Production'
"models:/Name/LATEST", # not lower case 'latest'
"models:Name/Stage", # missing slash
"models://Name/Stage", # hostnames are ignored, path too short
],
Expand Down Expand Up @@ -119,3 +121,8 @@ def test_get_model_name_and_version_with_latest():
"20",
)
mlflow_client_mock.assert_called_once_with("AdsModel1", None)
# Check that "latest" is case insensitive.
assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/lATest") == (
"AdsModel1",
"20",
)
12 changes: 12 additions & 0 deletions tests/store/model_registry/test_sqlalchemy_store.py
Expand Up @@ -347,6 +347,18 @@ def test_get_latest_versions(self):
),
{"Production": 3},
)
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["production"])
),
{"Production": 3},
) # The stages are case insensitive.
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["pROduction"])
),
{"Production": 3},
) # The stages are case insensitive.
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["None", "Production"])
Expand Down

0 comments on commit 271750b

Please sign in to comment.