Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Ann Zhang <ann.zhang@databricks.com>
  • Loading branch information
annzhang-db committed Nov 21, 2023
1 parent 6c100af commit 99200f1
Showing 1 changed file with 6 additions and 17 deletions.
Expand Up @@ -5,9 +5,8 @@
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.tracking.default_experiment.databricks_notebook_experiment_provider import (
DatabricksNotebookExperimentProvider,
DatabricksRepoNotebookExperimentProvider,
)
from mlflow.utils.mlflow_tags import MLFLOW_EXPERIMENT_SOURCE_ID, MLFLOW_EXPERIMENT_SOURCE_TYPE
from mlflow.utils.mlflow_tags import MLFLOW_EXPERIMENT_SOURCE_ID


def test_databricks_notebook_default_experiment_in_context():
Expand All @@ -32,17 +31,6 @@ def test_databricks_notebook_default_experiment_id():
)


def test_databricks_repo_notebook_default_experiment_in_context():
with mock.patch(
"mlflow.utils.databricks_utils.is_in_databricks_repo_notebook", return_value=True
):
assert DatabricksRepoNotebookExperimentProvider().in_context()
with mock.patch(
"mlflow.utils.databricks_utils.is_in_databricks_repo_notebook", return_value=False
):
assert not DatabricksRepoNotebookExperimentProvider().in_context()


def test_databricks_repo_notebook_default_experiment_gets_id_by_request():
with mock.patch(
"mlflow.utils.databricks_utils.get_notebook_id",
Expand All @@ -53,9 +41,10 @@ def test_databricks_repo_notebook_default_experiment_gets_id_by_request():
), mock.patch.object(
MlflowClient, "create_experiment", return_value="experiment_id"
) as create_experiment_mock:
returned_id = DatabricksRepoNotebookExperimentProvider().get_experiment_id()
DatabricksNotebookExperimentProvider._resolved_notebook_experiment_id = None
returned_id = DatabricksNotebookExperimentProvider().get_experiment_id()
assert returned_id == "experiment_id"
tags = {MLFLOW_EXPERIMENT_SOURCE_TYPE: "REPO_NOTEBOOK", MLFLOW_EXPERIMENT_SOURCE_ID: 1234}
tags = {MLFLOW_EXPERIMENT_SOURCE_ID: 1234}
create_experiment_mock.assert_called_once_with("/Repos/path", None, tags)


Expand All @@ -69,9 +58,9 @@ def test_databricks_repo_notebook_default_experiment_uses_fallback_notebook_id()
), mock.patch.object(
MlflowClient, "create_experiment"
) as create_experiment_mock:
DatabricksRepoNotebookExperimentProvider._resolved_repo_notebook_experiment_id = None
DatabricksNotebookExperimentProvider._resolved_notebook_experiment_id = None
create_experiment_mock.side_effect = MlflowException(
message="not enabled", error_code=INVALID_PARAMETER_VALUE
)
returned_id = DatabricksRepoNotebookExperimentProvider().get_experiment_id()
returned_id = DatabricksNotebookExperimentProvider().get_experiment_id()
assert returned_id == 1234

0 comments on commit 99200f1

Please sign in to comment.