Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Ann Zhang <ann.zhang@databricks.com>
  • Loading branch information
annzhang-db committed Nov 21, 2023
1 parent 8b69a60 commit 4f49f51
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 30 deletions.
@@ -1,5 +1,3 @@
import logging

from mlflow.exceptions import MlflowException
from mlflow.protos import databricks_pb2
from mlflow.tracking.client import MlflowClient
Expand All @@ -10,8 +8,6 @@
MLFLOW_EXPERIMENT_SOURCE_TYPE,
)

_logger = logging.getLogger(__name__)


class DatabricksNotebookExperimentProvider(DefaultExperimentProvider):
_resolved_notebook_experiment_id = None
Expand All @@ -20,8 +16,6 @@ def in_context(self):
return databricks_utils.is_in_databricks_notebook()

def get_experiment_id(self):
_logger.debug("get_experiment_id for DatabricksNotebookExperimentProvider")
print("get_experiment_id for DatabricksNotebookExperimentProvider")
if DatabricksNotebookExperimentProvider._resolved_notebook_experiment_id:
return DatabricksNotebookExperimentProvider._resolved_notebook_experiment_id

Expand All @@ -30,8 +24,6 @@ def get_experiment_id(self):
tags = {
MLFLOW_EXPERIMENT_SOURCE_ID: source_notebook_id,
}
print(f"source_notebook_id {source_notebook_id}")
print(f"source_notebook_name {source_notebook_name}")

# With the presence of the source id, the following is a get or create in which it will
# return the corresponding experiment if one exists for the repo notebook.
Expand All @@ -42,15 +34,12 @@ def get_experiment_id(self):
if e.error_code == databricks_pb2.ErrorCode.Name(
databricks_pb2.INVALID_PARAMETER_VALUE
):
print(f"it was not a repo notebook {e}")
# If determined that it is not a repo noetbook
experiment_id = source_notebook_id
else:
raise e

DatabricksNotebookExperimentProvider._resolved_notebook_experiment_id = experiment_id
_logger.debug(f"experiment_id = {experiment_id}")
print(f"experiment_id = {experiment_id}")

return experiment_id

Expand All @@ -62,8 +51,6 @@ def in_context(self):
return databricks_utils.is_in_databricks_repo_notebook()

def get_experiment_id(self):
_logger.debug("get_experiment_id for DatabricksREPONotebookExperimentProvider")
print("get_experiment_id for DatabricksREPONotebookExperimentProvider")
if DatabricksRepoNotebookExperimentProvider._resolved_repo_notebook_experiment_id:
return DatabricksRepoNotebookExperimentProvider._resolved_repo_notebook_experiment_id

Expand Down Expand Up @@ -93,6 +80,4 @@ def get_experiment_id(self):
DatabricksRepoNotebookExperimentProvider._resolved_repo_notebook_experiment_id = (
experiment_id
)
_logger.debug(f"experiment_id = {experiment_id}")
print(f"experiment_id = {experiment_id}")
return experiment_id
15 changes: 0 additions & 15 deletions mlflow/tracking/fluent.py
Expand Up @@ -293,11 +293,7 @@ def start_run(
global _active_run_stack
_validate_experiment_id_type(experiment_id)
# back compat for int experiment_id
_logger.debug("START RUN")
print("START RUN")
experiment_id = str(experiment_id) if isinstance(experiment_id, int) else experiment_id
_logger.debug(f"experiment_id = {experiment_id}")
print(f"experiment_id = {experiment_id}")
if len(_active_run_stack) > 0 and not nested:
raise Exception(
(
Expand All @@ -315,8 +311,6 @@ def start_run(
else:
existing_run_id = None
if existing_run_id:
_logger.debug("EXISTING RUN ID")
print("EXISTING RUN ID")
_validate_run_id(existing_run_id)
active_run_obj = client.get_run(existing_run_id)
# Check to see if experiment_id from environment matches experiment_id from set_experiment()
Expand Down Expand Up @@ -357,8 +351,6 @@ def start_run(
)
active_run_obj = client.get_run(existing_run_id)
else:
_logger.debug("ELSE")
print("ELSE")
parent_run_id = _active_run_stack[-1].info.run_id if len(_active_run_stack) > 0 else None

exp_id_for_run = experiment_id if experiment_id is not None else _get_experiment_id()
Expand All @@ -379,7 +371,6 @@ def start_run(

resolved_tags = context_registry.resolve_tags(user_specified_tags)

print(f"active_run_obj experiment_id {exp_id_for_run}")
active_run_obj = client.create_run(
experiment_id=exp_id_for_run,
tags=resolved_tags,
Expand Down Expand Up @@ -1919,15 +1910,9 @@ def _get_experiment_id_from_env():


def _get_experiment_id():
_logger.debug("call _get_experiment_id")
print("call _get_experiment_id")
if _active_experiment_id:
_logger.debug("return _active_experiment_id")
print("RETURN _active_experiment_id")
return _active_experiment_id
else:
_logger.debug("else case")
print("else case")
return _get_experiment_id_from_env() or default_experiment_registry.get_experiment_id()


Expand Down

0 comments on commit 4f49f51

Please sign in to comment.