Skip to content

Commit

Permalink
Update parse model URI to prevent breaking old cases while supporting aliases (#8322)
Browse files Browse the repository at this point in the history

* Update parse model URI to prevent breaking old cases while supporting aliases

Signed-off-by: Arpit Jasapara <arpit.jasapara@databricks.com>

* Applied comments

Signed-off-by: Arpit Jasapara <arpit.jasapara@databricks.com>

* Addressed comments

Signed-off-by: Arpit Jasapara <arpit.jasapara@databricks.com>

* Pass allow_fragments=False to urlparse

Signed-off-by: Arpit Jasapara <arpit.jasapara@databricks.com>

---------

Signed-off-by: Arpit Jasapara <arpit.jasapara@databricks.com>
  • Loading branch information
arpitjasa-db committed Apr 26, 2023
1 parent 64270e2 commit ef7b6ed
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 30 deletions.
50 changes: 24 additions & 26 deletions mlflow/store/artifact/utils/models.py
@@ -1,4 +1,3 @@
import re
from typing import NamedTuple, Optional
import urllib.parse

Expand All @@ -8,12 +7,6 @@

_MODELS_URI_SUFFIX_LATEST = "latest"

# This regex is used by _parse_model_uri and details for the regex match
# can be found in _improper_model_uri_msg.
_MODEL_URI_REGEX = re.compile(
r"^\/(?P<model_name>[\w \.\-]+)(\/(?P<suffix>[\w]+))?(@(?P<alias>[\w\-]+))?$"
)


def is_using_databricks_registry(uri):
profile_uri = get_databricks_profile_uri_from_artifact_uri(uri) or mlflow.get_registry_uri()
Expand Down Expand Up @@ -59,32 +52,37 @@ def _parse_model_uri(uri):
- (name, None, None, None) to look for the latest of all versions.
- (name, None, None, alias) to look for a registered model alias.
"""
parsed = urllib.parse.urlparse(uri)
parsed = urllib.parse.urlparse(uri, allow_fragments=False)
if parsed.scheme != "models":
raise MlflowException(_improper_model_uri_msg(uri))
path = parsed.path
m = _MODEL_URI_REGEX.match(path)
if m is None:
if not path.startswith("/") or len(path) <= 1:
raise MlflowException(_improper_model_uri_msg(uri))
gd = m.groupdict()
model_name = gd.get("model_name")
suffix = gd.get("suffix")
alias = gd.get("alias")
if (model_name.strip() == "") or (suffix and alias) or (suffix is None and alias is None):

parts = path.lstrip("/").split("/")
if len(parts) > 2 or parts[0].strip() == "":
raise MlflowException(_improper_model_uri_msg(uri))

if alias:
# The URI is an alias URI, e.g. "models:/AdsModel1@Champion"
return ParsedModelUri(model_name, alias=alias)
elif suffix.isdigit():
# The suffix is a specific version, e.g. "models:/AdsModel1/123"
return ParsedModelUri(model_name, version=suffix)
elif suffix.lower() == _MODELS_URI_SUFFIX_LATEST.lower():
# The suffix is the 'latest' string (case insensitive), e.g. "models:/AdsModel1/latest"
return ParsedModelUri(model_name)
if len(parts) == 2:
name, suffix = parts
if suffix.strip() == "":
raise MlflowException(_improper_model_uri_msg(uri))
# The URI is in the suffix format
if suffix.isdigit():
# The suffix is a specific version, e.g. "models:/AdsModel1/123"
return ParsedModelUri(name, version=suffix)
elif suffix.lower() == _MODELS_URI_SUFFIX_LATEST.lower():
# The suffix is the 'latest' string (case insensitive), e.g. "models:/AdsModel1/latest"
return ParsedModelUri(name)
else:
# The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production"
return ParsedModelUri(name, stage=suffix)
else:
# The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production"
return ParsedModelUri(model_name, stage=suffix)
# The URI is an alias URI, e.g. "models:/AdsModel1@Champion"
alias_parts = parts[0].rsplit("@", 1)
if len(alias_parts) != 2 or alias_parts[1].strip() == "":
raise MlflowException(_improper_model_uri_msg(uri))
return ParsedModelUri(alias_parts[0], alias=alias_parts[1])


def get_model_name_and_version(client, models_uri):
Expand Down
20 changes: 16 additions & 4 deletions tests/store/artifact/utils/test_model_utils.py
Expand Up @@ -12,6 +12,7 @@
[
("models:/AdsModel1/0", "AdsModel1", "0"),
("models:/Ads Model 1/12345", "Ads Model 1", "12345"),
("models://////Ads Model 1/12345", "Ads Model 1", "12345"), # many slashes
("models:/12345/67890", "12345", "67890"),
("models://profile@databricks/12345/67890", "12345", "67890"),
("models:/catalog.schema.model/0", "catalog.schema.model", "0"), # UC Model format
Expand All @@ -32,7 +33,14 @@ def test_parse_models_uri_with_version(uri, expected_name, expected_version):
("models:/AdsModel1/production", "AdsModel1", "production"), # case insensitive
("models:/AdsModel1/pROduction", "AdsModel1", "pROduction"), # case insensitive
("models:/Ads Model 1/None", "Ads Model 1", "None"),
("models://////Ads Model 1/Staging", "Ads Model 1", "Staging"), # many slashes
("models://scope:key@databricks/Ads Model 1/None", "Ads Model 1", "None"),
(
"models:/Name/Stage@Alias",
"Name",
"Stage@Alias",
), # technically allowed, but the backend would throw
("models:/Name@Alias/Stage", "Name@Alias", "Stage"),
],
)
def test_parse_models_uri_with_stage(uri, expected_name, expected_stage):
Expand All @@ -50,6 +58,7 @@ def test_parse_models_uri_with_stage(uri, expected_name, expected_stage):
("models:/AdsModel1/Latest", "AdsModel1"), # case insensitive
("models:/AdsModel1/LATEST", "AdsModel1"), # case insensitive
("models:/Ads Model 1/latest", "Ads Model 1"),
("models://////Ads Model 1/latest", "Ads Model 1"), # many slashes
("models://scope:key@databricks/Ads Model 1/latest", "Ads Model 1"),
("models:/catalog.schema.model/latest", "catalog.schema.model"), # UC Model format
],
Expand All @@ -70,7 +79,14 @@ def test_parse_models_uri_with_latest(uri, expected_name):
("models:/AdsModel1@cHAmpion", "AdsModel1", "cHAmpion"), # case insensitive
("models:/Ads Model 1@challenger", "Ads Model 1", "challenger"),
("models://scope:key/Ads Model 1@None", "Ads Model 1", "None"),
("models://////Ads Model 1@TestAlias", "Ads Model 1", "TestAlias"), # many slashes
("models:/catalog.schema.model@None", "catalog.schema.model", "None"), # UC Model format
("models:/A!&#$%;{}()[]CrazyName@Alias", "A!&#$%;{}()[]CrazyName", "Alias"),
(
"models:/NameWith@IntheMiddle@Alias",
"NameWith@IntheMiddle",
"Alias",
), # check for model name with alias
],
)
def test_parse_models_uri_with_alias(uri, expected_name, expected_alias):
Expand All @@ -93,12 +109,8 @@ def test_parse_models_uri_with_alias(uri, expected_name, expected_alias):
"models:/Name/", # empty suffix
"models:/Name@", # empty alias
"models:/Name/Stage/0", # too many specifiers
"models:/Name/Stage@Alias", # stage and alias both specified
"models:/Name@alias/Stage", # Stage and alias both specified
"models:/Name@Alias@other", # too many aliases
"models:Name/Stage", # missing slash
"models://Name/Stage", # hostnames are ignored, path too short
"models://Name@te#ty;", # invalid characters
],
)
def test_parse_models_uri_invalid_input(uri):
Expand Down

0 comments on commit ef7b6ed

Please sign in to comment.