Skip to content

Commit

Permalink
Improve UC model registry client error messages when specifying nonex…
Browse files Browse the repository at this point in the history
…istent source files or missing Python dependencies (#8324)

* Improve UC model registry client error messages when specifying nonexistent source files or missing Python dependencies

Signed-off-by: Sid Murching <sid.murching@databricks.com>

* Fix lint

Signed-off-by: Sid Murching <sid.murching@databricks.com>

* Fix windows test?

Signed-off-by: Sid Murching <sid.murching@databricks.com>

---------

Signed-off-by: Sid Murching <sid.murching@databricks.com>
  • Loading branch information
smurching committed Apr 26, 2023
1 parent 8860da3 commit 64270e2
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 3 deletions.
14 changes: 11 additions & 3 deletions mlflow/store/_unity_catalog/registry/rest_store.py
Expand Up @@ -387,9 +387,17 @@ def create_model_version(
)
)
with tempfile.TemporaryDirectory() as tmpdir:
local_model_dir = mlflow.artifacts.download_artifacts(
artifact_uri=source, dst_path=tmpdir, tracking_uri=self.tracking_uri
)
try:
local_model_dir = mlflow.artifacts.download_artifacts(
artifact_uri=source, dst_path=tmpdir, tracking_uri=self.tracking_uri
)
except Exception as e:
raise MlflowException(
f"Unable to download model artifacts from source artifact location "
f"'{source}' in order to upload them to Unity Catalog. Please ensure "
f"the source artifact location exists and that you can download from "
f"it via mlflow.artifacts.download_artifacts()"
) from e
self._validate_model_signature(local_model_dir)
model_version = self._call_endpoint(CreateModelVersionRequest, req_body).model_version
version_number = model_version.version
Expand Down
18 changes: 18 additions & 0 deletions mlflow/store/_unity_catalog/registry/utils.py
Expand Up @@ -49,8 +49,26 @@ def get_artifact_repo_from_storage_info(
:param storage_location: Storage location of the model version
:param scoped_token: Protobuf scoped token to use to authenticate to blob storage
"""
try:
return _get_artifact_repo_from_storage_info(
storage_location=storage_location, scoped_token=scoped_token
)
except ImportError as e:
raise MlflowException(
"Unable to import necessary dependencies to access model version files in "
"Unity Catalog. Please ensure you have the necessary dependencies installed, "
"e.g. by running 'pip install mlflow[databricks]' or "
"'pip install mlflow-skinny[databricks]'"
) from e


def _get_artifact_repo_from_storage_info(
storage_location: str, scoped_token: TemporaryCredentials
) -> ArtifactRepository:
credential_type = scoped_token.WhichOneof("credentials")
if credential_type == "aws_temp_credentials":
# Verify upfront that boto3 is importable
import boto3 # pylint: disable=unused-import
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository

aws_creds = scoped_token.aws_temp_credentials
Expand Down
Expand Up @@ -129,6 +129,46 @@ def local_model_dir(tmp_path):
yield tmp_path


def test_create_model_version_nonexistent_directory(store, tmp_path):
fake_directory = str(tmp_path.joinpath("myfakepath"))
with pytest.raises(
MlflowException,
match="Unable to download model artifacts from source artifact location",
):
store.create_model_version(name="mymodel", source=fake_directory)


def test_create_model_version_missing_python_deps(store, local_model_dir):
access_key_id = "fake-key"
secret_access_key = "secret-key"
session_token = "session-token"
aws_temp_creds = TemporaryCredentials(
aws_temp_credentials=AwsCredentials(
access_key_id=access_key_id,
secret_access_key=secret_access_key,
session_token=session_token,
)
)
storage_location = "s3://blah"
source = str(local_model_dir)
model_name = "model_1"
version = "1"
with mock.patch(
"mlflow.utils.rest_utils.http_request",
side_effect=get_request_mock(
name=model_name,
version=version,
temp_credentials=aws_temp_creds,
storage_location=storage_location,
source=source,
),
), mock.patch.dict("sys.modules", {"boto3": None}), pytest.raises(
MlflowException,
match="Unable to import necessary dependencies to access model version files",
):
store.create_model_version(name=model_name, source=str(local_model_dir))


def test_create_model_version_missing_mlmodel(store, tmp_path):
with pytest.raises(
MlflowException,
Expand Down

0 comments on commit 64270e2

Please sign in to comment.