Skip to content

Commit

Permalink
Disallow local source for model version (#7908)
Browse files Browse the repository at this point in the history
* Disallow local source for model version

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Fix tests

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Fix is_local_uri

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Use runs URI

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Fix tests

Signed-off-by: harupy <hkawamura0130@gmail.com>

---------

Signed-off-by: harupy <hkawamura0130@gmail.com>
  • Loading branch information
harupy authored and dbczumar committed Feb 28, 2023
1 parent 6b0c333 commit d622ee7
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ public class ModelRegistryMlflowClientTest {
private final TestClientProvider testClientProvider = new TestClientProvider();

private MlflowClient client;
private String source;

private String modelName;
private File tempDir;
private File tempFile;

private static final String content = "Hello, Worldz!";

Expand All @@ -60,16 +59,18 @@ public void before() throws IOException {

RunInfo runCreated = client.createRun(expId);
String runId = runCreated.getRunUuid();
source = String.format("runs:/%s/model", runId);

tempDir = Files.createTempDirectory("tempDir").toFile();
tempFile = Files.createTempFile(tempDir.toPath(), "file", ".txt").toFile();

File tempDir = Files.createTempDirectory("tempDir").toFile();
File tempFile = Files.createTempFile(tempDir.toPath(), "file", ".txt").toFile();
FileUtils.writeStringToFile(tempFile, content, StandardCharsets.UTF_8);
client.logArtifact(runId, tempFile, "model");

client.sendPost("registered-models/create",
mapper.makeCreateModel(modelName));

client.sendPost("model-versions/create",
mapper.makeCreateModelVersion(modelName, runId, tempDir.getAbsolutePath()));
mapper.makeCreateModelVersion(modelName, runId, String.format("runs:/%s/model", runId)));
}

@AfterTest
Expand Down Expand Up @@ -121,7 +122,7 @@ public void testGetRegisteredModel() {
@Test
public void testGetModelVersionDownloadUri() {
String downloadUri = client.getModelVersionDownloadUri(modelName, "1");
Assert.assertEquals(tempDir.getAbsolutePath(), downloadUri);
Assert.assertEquals(source, downloadUri);
}

@Test
Expand Down Expand Up @@ -169,14 +170,14 @@ public void testSearchModelVersions() {

// create new model version of existing registered model
String newVersionRunId = "newVersionRunId";
String newVersionSource = "newVersionSource";
String newVersionSource = "runs:/newVersionRunId/model";
client.sendPost("model-versions/create",
mapper.makeCreateModelVersion(modelName, newVersionRunId, newVersionSource));

// create new registered model
String modelName2 = "modelName2";
String runId2 = "runId2";
String source2 = "source2";
String source2 = "runs:/runId2/model";
client.sendPost("registered-models/create",
mapper.makeCreateModel(modelName2));
client.sendPost("model-versions/create",
Expand Down
8 changes: 8 additions & 0 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
from mlflow.utils.validation import _validate_batch_log_api_req
from mlflow.utils.string_utils import is_string_type
from mlflow.utils.uri import is_local_uri
from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1376,6 +1377,13 @@ def _create_model_version():
"description": [_assert_string],
},
)

if is_local_uri(request_message.source):
raise MlflowException(
f"Model version source cannot be a local path: '{request_message.source}'",
INVALID_PARAMETER_VALUE,
)

model_version = _get_model_registry_store().create_model_version(
name=request_message.name,
source=request_message.source,
Expand Down
10 changes: 9 additions & 1 deletion mlflow/utils/uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,19 @@ def is_local_uri(uri):
"""Returns true if this is a local file path (/foo or file:/foo)."""
if uri == "databricks":
return False

parsed_uri = urllib.parse.urlparse(uri)
if parsed_uri.hostname:
return False

scheme = parsed_uri.scheme
return scheme == "" or scheme == "file"
if scheme == "" or scheme == "file":
return True

if is_windows() and len(scheme) == 1 and scheme.lower() == pathlib.Path(uri).drive.lower()[0]:
return True

return False


def is_http_uri(uri):
Expand Down
4 changes: 2 additions & 2 deletions tests/server/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def test_create_model_version(mock_get_request_message, mock_model_registry_stor
run_link = "localhost:5000/path/to/run"
mock_get_request_message.return_value = CreateModelVersion(
name="model_1",
source="A/B",
source=f"runs:/{run_id}",
run_id=run_id,
run_link=run_link,
tags=[tag.to_proto() for tag in tags],
Expand All @@ -469,7 +469,7 @@ def test_create_model_version(mock_get_request_message, mock_model_registry_stor
resp = _create_model_version()
_, args = mock_model_registry_store.create_model_version.call_args
assert args["name"] == "model_1"
assert args["source"] == "A/B"
assert args["source"] == f"runs:/{run_id}"
assert args["run_id"] == run_id
assert {tag.key: tag.value for tag in args["tags"]} == {tag.key: tag.value for tag in tags}
assert args["run_link"] == run_link
Expand Down
36 changes: 18 additions & 18 deletions tests/tracking/test_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def test_search_model_versions_filter_string(
for name in names + names[:10]:
# Sleep for unique creation_time to make search results deterministic
time.sleep(0.001)
mvs.append(client.create_model_version(name, "path/to/model", "run_id"))
mvs.append(client.create_model_version(name, "runs:/run_id/model", "run_id"))
for mv in mvs:
assert isinstance(mv, ModelVersion)
mvs = mvs[::-1]
Expand Down Expand Up @@ -313,7 +313,7 @@ def test_search_model_versions_max_results(client, max_results):
for name in names + names[:10]:
# Sleep for unique creation_time to make search results deterministic
time.sleep(0.001)
mvs.append(client.create_model_version(name, "path/to/model", "run_id"))
mvs.append(client.create_model_version(name, "runs:/run_id/model", "run_id"))
for mv in mvs:
assert isinstance(mv, ModelVersion)
mvs = mvs[::-1]
Expand Down Expand Up @@ -363,7 +363,7 @@ def test_search_model_versions_order_by(
for name in names + names[:10]:
# Sleep for unique creation_time to make search results deterministic
time.sleep(0.001)
mvs.append(client.create_model_version(name, "path/to/model", "run_id"))
mvs.append(client.create_model_version(name, "runs:/run_id/model", "run_id"))
for mv in mvs:
assert isinstance(mv, ModelVersion)
mvs = mvs[::-1]
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_create_and_query_model_version_flow(client):
name = "CreateMVTest"
tags = {"key": "value", "another key": "some other value", "numeric value": 12345}
client.create_registered_model(name)
mv1 = client.create_model_version(name, "path/to/model", "run_id_1", tags)
mv1 = client.create_model_version(name, "runs:/run_id/model", "run_id_1", tags)
assert mv1.version == "1"
assert mv1.name == name
assert mv1.tags == {"key": "value", "another key": "some other value", "numeric value": "12345"}
Expand All @@ -410,7 +410,7 @@ def test_create_and_query_model_version_flow(client):
assert [rm.latest_versions for rm in client.search_registered_models() if rm.name == name] == [
[mvd1]
]
mv2 = client.create_model_version(name, "another_path/to/model", "run_id_1")
mv2 = client.create_model_version(name, "runs:/run_id/another_model", "run_id_1")
assert mv2.version == "2"
assert mv2.name == name
mvd2 = client.get_model_version(name, "2")
Expand All @@ -421,18 +421,18 @@ def test_create_and_query_model_version_flow(client):
assert {mv.version for mv in model_versions_by_name} == {"1", "2"}
assert {mv.name for mv in model_versions_by_name} == {name}

mv3 = client.create_model_version(name, "another_path/to/model", "run_id_2")
mv3 = client.create_model_version(name, "runs:/run_id/another_model", "run_id_2")
assert mv3.version == "3"
assert client.search_model_versions("source_path = 'path/to/model'") == [mvd1]
assert client.search_model_versions("source_path = 'runs:/run_id/model'") == [mvd1]
assert client.search_model_versions("run_id = 'run_id_1'") == [mvd2, mvd1]

assert client.get_model_version_download_uri(name, "1") == "path/to/model"
assert client.get_model_version_download_uri(name, "1") == "runs:/run_id/model"


def test_get_model_version(client):
name = "GetModelVersionTest"
client.create_registered_model(name)
client.create_model_version(name, "path/to/model", "run_id_1")
client.create_model_version(name, "runs:/run_id/model", "run_id_1")
model_version = client.get_model_version(name, "1")
assert model_version.name == name
assert model_version.version == "1"
Expand All @@ -453,7 +453,7 @@ def test_update_model_version_flow(client):
assert_is_between(start_time_0, end_time_0, rmd1.last_updated_timestamp)

start_time_1 = get_current_time_millis()
mv1 = client.create_model_version(name, "path/to/model", "run_id_1")
mv1 = client.create_model_version(name, "runs:/run_id/model", "run_id_1")
end_time_1 = get_current_time_millis()
assert mv1.version == "1"
assert mv1.name == name
Expand All @@ -470,7 +470,7 @@ def test_update_model_version_flow(client):
assert [rm.latest_versions for rm in client.search_registered_models() if rm.name == name] == [
[mvd1]
]
mv2 = client.create_model_version(name, "another_path/to/model", "run_id_1")
mv2 = client.create_model_version(name, "runs:/run_id/another_model", "run_id_1")
assert mv2.version == "2"
assert mv2.name == name
mvd2 = client.get_model_version(name, "2")
Expand Down Expand Up @@ -529,7 +529,7 @@ def test_latest_models(client):
for version, stage in version_stage_mapping:
# Sleep for unique creation_time to make search results deterministic
time.sleep(0.001)
mv = client.create_model_version(name, "path/to/model", "run_id")
mv = client.create_model_version(name, "runs:/run_id/model", "run_id")
assert mv.version == version
if stage != "None":
client.transition_model_version_stage(name, version, stage=stage)
Expand Down Expand Up @@ -557,7 +557,7 @@ def test_delete_model_version_flow(client):
assert_is_between(start_time_0, end_time_0, rmd1.last_updated_timestamp)

start_time_1 = get_current_time_millis()
mv1 = client.create_model_version(name, "path/to/model", "run_id_1")
mv1 = client.create_model_version(name, "runs:/run_id/model", "run_id_1")
end_time_1 = get_current_time_millis()
assert mv1.version == "1"
assert mv1.name == name
Expand All @@ -570,10 +570,10 @@ def test_delete_model_version_flow(client):
assert_is_between(start_time_0, end_time_0, rmd2.creation_timestamp)
assert_is_between(start_time_1, end_time_1, rmd2.last_updated_timestamp)

mv2 = client.create_model_version(name, "another_path/to/model", "run_id_1")
mv2 = client.create_model_version(name, "runs:/run_id/another_model", "run_id_1")
assert mv2.version == "2"
assert mv2.name == name
mv3 = client.create_model_version(name, "a/b/c", "run_id_2")
mv3 = client.create_model_version(name, "runs:/run_id_2/a/b/c", "run_id_2")
assert mv3.version == "3"
assert mv3.name == name
model_versions_detailed = [
Expand Down Expand Up @@ -613,7 +613,7 @@ def test_delete_model_version_flow(client):
assert {mv.version for mv in client.search_model_versions("name = '%s'" % name)} == {"2"}

# new model versions will not reuse existing version numbers
mv4 = client.create_model_version(name, "a/b/c", "run_id_2")
mv4 = client.create_model_version(name, "runs:/run_id_2/a/b/c", "run_id_2")
assert mv4.version == "4"
assert mv4.name == name
assert {mv.version for mv in client.search_model_versions("name = '%s'" % name)} == {
Expand All @@ -625,7 +625,7 @@ def test_delete_model_version_flow(client):
def test_set_delete_model_version_tag_flow(client):
name = "SetDeleteMVTagTest"
client.create_registered_model(name)
client.create_model_version(name, "path/to/model", "run_id_1")
client.create_model_version(name, "runs:/run_id/model", "run_id_1")
model_version_detailed = client.get_model_version(name, "1")
assert model_version_detailed.tags == {}
tags = {"key": "value", "numeric value": 12345}
Expand All @@ -641,6 +641,6 @@ def test_set_delete_model_version_tag_flow(client):
def test_set_model_version_tag_with_empty_string_as_value(client):
name = "SetMVTagEmptyValueTest"
client.create_registered_model(name)
client.create_model_version(name, "path/to/model", "run_id_1")
client.create_model_version(name, "runs:/run_id/model", "run_id_1")
client.set_model_version_tag(name, "1", "tag_key", "")
assert {"tag_key": ""}.items() <= client.get_model_version(name, "1").tags.items()
18 changes: 18 additions & 0 deletions tests/tracking/test_rest_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,3 +1042,21 @@ def get(self, key, default=None):
metric_key="mock_key",
max_results=25000,
)


def test_create_model_version_with_local_source(mlflow_client):
name = "mode"
mlflow_client.create_registered_model(name)
response = requests.post(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
json={
"name": name,
"source": "file:///tmp/model",
"run_id": "run_id",
},
)
assert response.status_code == 400
assert response.json() == {
"error_code": "INVALID_PARAMETER_VALUE",
"message": "Model version source cannot be a local path: 'file:///tmp/model'",
}
13 changes: 12 additions & 1 deletion tests/utils/test_uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_get_db_info_from_uri_errors_invalid_profile(server_uri):
get_db_info_from_uri(server_uri)


def test_uri_types():
def test_is_local_uri():
assert is_local_uri("mlruns")
assert is_local_uri("./mlruns")
assert is_local_uri("file:///foo/mlruns")
Expand All @@ -99,12 +99,23 @@ def test_uri_types():
assert not is_local_uri("databricks:whatever")
assert not is_local_uri("databricks://whatever")


@pytest.mark.skipif(not is_windows(), reason="Windows-only test")
def test_is_local_uri_windows():
assert is_local_uri("C:\\foo\\mlruns")
assert is_local_uri("C:/foo/mlruns")
assert is_local_uri("file:///C:\\foo\\mlruns")


def test_is_databricks_uri():
assert is_databricks_uri("databricks")
assert is_databricks_uri("databricks:whatever")
assert is_databricks_uri("databricks://whatever")
assert not is_databricks_uri("mlruns")
assert not is_databricks_uri("http://whatever")


def test_is_http_uri():
assert is_http_uri("http://whatever")
assert is_http_uri("https://whatever")
assert not is_http_uri("file://whatever")
Expand Down

0 comments on commit d622ee7

Please sign in to comment.