# test_model_utils.py
# Tests for model URI parsing and name/version resolution in
# mlflow.store.artifact.utils.models.
import pytest
from unittest import mock
from mlflow.exceptions import MlflowException
from mlflow.store.artifact.utils.models import _parse_model_uri, get_model_name_and_version
from mlflow.tracking import MlflowClient
from mlflow.entities.model_registry import ModelVersion
@pytest.mark.parametrize(
    ("uri", "expected_name", "expected_version"),
    [
        ("models:/AdsModel1/0", "AdsModel1", 0),
        ("models:/Ads Model 1/12345", "Ads Model 1", 12345),
        ("models:/12345/67890", "12345", 67890),
        ("models://profile@databricks/12345/67890", "12345", 67890),
    ],
)
def test_parse_models_uri_with_version(uri, expected_name, expected_version):
    """A numeric trailing segment parses to an integer version and no stage.

    Model names may contain spaces or be purely numeric. A ``profile@databricks``
    authority in front of the path does not affect the parsed name/version.
    """
    parsed_name, parsed_version, parsed_stage = _parse_model_uri(uri)
    assert parsed_name == expected_name
    assert parsed_version == expected_version
    assert parsed_stage is None
@pytest.mark.parametrize(
    ("uri", "expected_name", "expected_stage"),
    [
        ("models:/AdsModel1/Production", "AdsModel1", "Production"),
        ("models:/AdsModel1/production", "AdsModel1", "production"),  # casing passed through
        ("models:/AdsModel1/pROduction", "AdsModel1", "pROduction"),  # casing passed through
        # The literal string "None" is a stage name here, not Python's None.
        ("models:/Ads Model 1/None", "Ads Model 1", "None"),
        ("models://scope:key@databricks/Ads Model 1/None", "Ads Model 1", "None"),
    ],
)
def test_parse_models_uri_with_stage(uri, expected_name, expected_stage):
    """A non-numeric trailing segment is returned verbatim as the stage, with no version."""
    parsed_name, parsed_version, parsed_stage = _parse_model_uri(uri)
    assert parsed_name == expected_name
    assert parsed_version is None
    assert parsed_stage == expected_stage
@pytest.mark.parametrize(
    ("uri", "expected_name"),
    [
        ("models:/AdsModel1/latest", "AdsModel1"),
        ("models:/AdsModel1/Latest", "AdsModel1"),  # alias matched regardless of casing
        ("models:/AdsModel1/LATEST", "AdsModel1"),  # alias matched regardless of casing
        ("models:/Ads Model 1/latest", "Ads Model 1"),
        ("models://scope:key@databricks/Ads Model 1/latest", "Ads Model 1"),
    ],
)
def test_parse_models_uri_with_latest(uri, expected_name):
    """The ``latest`` alias, in any casing, parses to neither a version nor a stage."""
    parsed_name, parsed_version, parsed_stage = _parse_model_uri(uri)
    assert parsed_name == expected_name
    assert parsed_version is None
    assert parsed_stage is None
@pytest.mark.parametrize(
    # Keep a plain string: with one argname, pytest then treats each entry as a single value.
    "uri",
    [
        "notmodels:/NameOfModel/12345",  # scheme is not "models" (version form)
        "notmodels:/NameOfModel/StageName",  # scheme is not "models" (stage form)
        "models:/",  # model name missing
        "models:/Name/Stage/0",  # more than name + one specifier
        "models:Name/Stage",  # no slash after the scheme
        "models://Name/Stage",  # "Name" is taken as the host, leaving too short a path
    ],
)
def test_parse_models_uri_invalid_input(uri):
    """Malformed models URIs are rejected with an MlflowException."""
    with pytest.raises(MlflowException, match=r"Not a proper models"):
        _parse_model_uri(uri)
def test_get_model_name_and_version_with_version():
    """An explicit version is returned (as a string) without querying latest versions."""
    with mock.patch.object(
        MlflowClient, "get_latest_versions", return_value=[]
    ) as latest_versions_mock:
        client = MlflowClient()
        resolved = get_model_name_and_version(client, "models:/AdsModel1/123")
        assert resolved == ("AdsModel1", "123")
        latest_versions_mock.assert_not_called()
def test_get_model_name_and_version_with_stage():
    """A stage URI queries that stage and resolves to the newest returned version."""
    production_versions = [
        ModelVersion(name="mv1", version="10", creation_timestamp=123, current_stage="Production"),
        ModelVersion(name="mv2", version="15", creation_timestamp=124, current_stage="Production"),
    ]
    with mock.patch.object(
        MlflowClient, "get_latest_versions", return_value=production_versions
    ) as latest_versions_mock:
        resolved = get_model_name_and_version(MlflowClient(), "models:/AdsModel1/Production")
        assert resolved == ("AdsModel1", "15")
        # The stage is forwarded as a single-element list.
        latest_versions_mock.assert_called_once_with("AdsModel1", ["Production"])
def test_get_model_name_and_version_with_latest():
    """The ``latest`` alias resolves to the newest version across all stages.

    ``get_latest_versions`` is mocked to return the same list whatever its
    arguments. The returned version alone therefore cannot show that
    ``lATest`` was recognised as the alias rather than treated as a stage
    named "lATest". For both spellings, the test also checks that the mock
    received ``None`` instead of a stage list.
    """
    with mock.patch.object(
        MlflowClient,
        "get_latest_versions",
        return_value=[
            ModelVersion(
                name="mv1", version="10", creation_timestamp=123, current_stage="Production"
            ),
            ModelVersion(name="mv3", version="20", creation_timestamp=125, current_stage="None"),
            ModelVersion(name="mv2", version="15", creation_timestamp=124, current_stage="Staging"),
        ],
    ) as mlflow_client_mock:
        assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/latest") == (
            "AdsModel1",
            "20",
        )
        # The alias is not forwarded as a stage list: stages argument is None.
        mlflow_client_mock.assert_called_once_with("AdsModel1", None)

        # Check that "latest" is case insensitive. Reset the call record first so the
        # assertion below sees only this second lookup; reset_mock() keeps return_value.
        mlflow_client_mock.reset_mock()
        assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/lATest") == (
            "AdsModel1",
            "20",
        )
        mlflow_client_mock.assert_called_once_with("AdsModel1", None)