Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass in sourceId tag in all cases #10464

Merged
merged 14 commits into from Dec 10, 2023
Expand Up @@ -3,55 +3,42 @@
from mlflow.tracking.client import MlflowClient
from mlflow.tracking.default_experiment.abstract_context import DefaultExperimentProvider
from mlflow.utils import databricks_utils
from mlflow.utils.mlflow_tags import (
MLFLOW_EXPERIMENT_SOURCE_ID,
MLFLOW_EXPERIMENT_SOURCE_TYPE,
)
from mlflow.utils.mlflow_tags import MLFLOW_EXPERIMENT_SOURCE_ID, MLFLOW_EXPERIMENT_SOURCE_TYPE


class DatabricksNotebookExperimentProvider(DefaultExperimentProvider):
    """Default-experiment provider that maps a Databricks notebook to an experiment.

    When MLflow runs inside a Databricks notebook and the user has not set an
    experiment explicitly, this provider supplies the notebook's ID as the
    experiment ID (in Databricks, each notebook has a backing experiment whose
    ID equals the notebook ID).
    """

    def in_context(self):
        # Applies only while executing inside a Databricks notebook.
        return databricks_utils.is_in_databricks_notebook()

    def get_experiment_id(self):
        # The notebook ID doubles as the default experiment ID on Databricks.
        return databricks_utils.get_notebook_id()


class DatabricksRepoNotebookExperimentProvider(DefaultExperimentProvider):
_resolved_repo_notebook_experiment_id = None
_resolved_notebook_experiment_id = None

def in_context(self):
return databricks_utils.is_in_databricks_repo_notebook()
return databricks_utils.is_in_databricks_notebook()

def get_experiment_id(self):
if DatabricksRepoNotebookExperimentProvider._resolved_repo_notebook_experiment_id:
return DatabricksRepoNotebookExperimentProvider._resolved_repo_notebook_experiment_id
if DatabricksNotebookExperimentProvider._resolved_notebook_experiment_id:
return DatabricksNotebookExperimentProvider._resolved_notebook_experiment_id

source_notebook_id = databricks_utils.get_notebook_id()
source_notebook_name = databricks_utils.get_notebook_path()
tags = {
MLFLOW_EXPERIMENT_SOURCE_TYPE: "REPO_NOTEBOOK",
MLFLOW_EXPERIMENT_SOURCE_ID: source_notebook_id,
}

# With the presence of the above tags, the following is a get or create in which it will
if databricks_utils.is_in_databricks_repo_notebook():
tags[MLFLOW_EXPERIMENT_SOURCE_TYPE] = "REPO_NOTEBOOK"

# With the presence of the source id, the following is a get or create in which it will
# return the corresponding experiment if one exists for the repo notebook.
# If no corresponding experiment exist, it will create a new one and return
# the newly created experiment ID.
# For non-repo notebooks, it will raise an exception and we will use source_notebook_id
try:
experiment_id = MlflowClient().create_experiment(source_notebook_name, None, tags)
except MlflowException as e:
if e.error_code == databricks_pb2.ErrorCode.Name(
databricks_pb2.INVALID_PARAMETER_VALUE
):
# If repo notebook experiment creation isn't enabled, fall back to
# using the notebook ID
# If determined that it is not a repo notebook
experiment_id = source_notebook_id
else:
raise e
Comment on lines 31 to 40
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@annzhang-db I thought we would get a RESOURCE_ALREADY_EXISTS exception if we call create_experiment() with a non-repo notebook path that already exists. Is that what the backend does?

If the backend does indeed return RESOURCE_ALREADY_EXISTS, I think that would break the current implementation of DatabricksNotebookExperimentProvider in this PR; have we tested this thoroughly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we get a RESOURCE_ALREADY_EXISTS exception if we call create_experiment() with a non-repo notebook path that already exists AND no sourceType/sourceId tags are passed in. Since we are passing in the sourceId here, it will actually go into the previous case (in the backend PR — I'll leave a comment there) and raise an INVALID_PARAMETER_VALUE error for a sourceId without a sourceType.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if I log a param in a non-repo notebook, then attach/detach from the cluster, then try to log a param again, won't this default experiment provider try to call create_experiment() under the hood and then fail at the user level with RESOURCE_ALREADY_EXISTS, which will break the user's workflow?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will call create_experiment() with a sourceId tag, which will not fail with RESOURCE_ALREADY_EXISTS. Only if create_experiment() is called without sourceId tag will it fail with RESOURCE_ALREADY_EXISTS.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, got it!


DatabricksRepoNotebookExperimentProvider._resolved_repo_notebook_experiment_id = (
experiment_id
)
DatabricksNotebookExperimentProvider._resolved_notebook_experiment_id = experiment_id

return experiment_id
6 changes: 1 addition & 5 deletions mlflow/tracking/default_experiment/registry.py
Expand Up @@ -6,17 +6,13 @@
from mlflow.tracking.default_experiment import DEFAULT_EXPERIMENT_ID
from mlflow.tracking.default_experiment.databricks_notebook_experiment_provider import (
DatabricksNotebookExperimentProvider,
DatabricksRepoNotebookExperimentProvider,
)

_logger = logging.getLogger(__name__)
# Listed below are the list of providers, which are used to provide MLflow Experiment IDs based on
# the current context where the MLflow client is running when the user has not explicitly set
# an experiment. The order below is the order in which the these providers are registered.
_EXPERIMENT_PROVIDERS = (
DatabricksRepoNotebookExperimentProvider,
DatabricksNotebookExperimentProvider,
)
_EXPERIMENT_PROVIDERS = (DatabricksNotebookExperimentProvider,)


class DefaultExperimentProviderRegistry:
Expand Down
Expand Up @@ -5,7 +5,6 @@
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.tracking.default_experiment.databricks_notebook_experiment_provider import (
DatabricksNotebookExperimentProvider,
DatabricksRepoNotebookExperimentProvider,
)
from mlflow.utils.mlflow_tags import MLFLOW_EXPERIMENT_SOURCE_ID, MLFLOW_EXPERIMENT_SOURCE_TYPE

Expand All @@ -16,24 +15,22 @@ def test_databricks_notebook_default_experiment_in_context():


def test_databricks_notebook_default_experiment_id():
with mock.patch("mlflow.utils.databricks_utils.get_notebook_id") as patch_notebook_id:
with mock.patch.object(
MlflowClient,
"create_experiment",
side_effect=MlflowException(message="Error message", error_code=INVALID_PARAMETER_VALUE),
), mock.patch(
"mlflow.utils.databricks_utils.get_notebook_path",
return_value="path",
), mock.patch(
"mlflow.utils.databricks_utils.get_notebook_id"
) as patch_notebook_id:
assert (
DatabricksNotebookExperimentProvider().get_experiment_id()
== patch_notebook_id.return_value
)


def test_databricks_repo_notebook_default_experiment_in_context():
with mock.patch(
"mlflow.utils.databricks_utils.is_in_databricks_repo_notebook", return_value=True
):
assert DatabricksRepoNotebookExperimentProvider().in_context()
with mock.patch(
"mlflow.utils.databricks_utils.is_in_databricks_repo_notebook", return_value=False
):
assert not DatabricksRepoNotebookExperimentProvider().in_context()


def test_databricks_repo_notebook_default_experiment_gets_id_by_request():
with mock.patch(
"mlflow.utils.databricks_utils.get_notebook_id",
Expand All @@ -44,7 +41,8 @@ def test_databricks_repo_notebook_default_experiment_gets_id_by_request():
), mock.patch.object(
MlflowClient, "create_experiment", return_value="experiment_id"
) as create_experiment_mock:
returned_id = DatabricksRepoNotebookExperimentProvider().get_experiment_id()
DatabricksNotebookExperimentProvider._resolved_notebook_experiment_id = None
returned_id = DatabricksNotebookExperimentProvider().get_experiment_id()
assert returned_id == "experiment_id"
tags = {MLFLOW_EXPERIMENT_SOURCE_TYPE: "REPO_NOTEBOOK", MLFLOW_EXPERIMENT_SOURCE_ID: 1234}
create_experiment_mock.assert_called_once_with("/Repos/path", None, tags)
Expand All @@ -60,9 +58,9 @@ def test_databricks_repo_notebook_default_experiment_uses_fallback_notebook_id()
), mock.patch.object(
MlflowClient, "create_experiment"
) as create_experiment_mock:
DatabricksRepoNotebookExperimentProvider._resolved_repo_notebook_experiment_id = None
DatabricksNotebookExperimentProvider._resolved_notebook_experiment_id = None
create_experiment_mock.side_effect = MlflowException(
message="not enabled", error_code=INVALID_PARAMETER_VALUE
)
returned_id = DatabricksRepoNotebookExperimentProvider().get_experiment_id()
returned_id = DatabricksNotebookExperimentProvider().get_experiment_id()
assert returned_id == 1234