Skip to content

Commit

Permalink
Remove arguments to MlflowClient.create_run that are now in tags (m…
Browse files Browse the repository at this point in the history
  • Loading branch information
acroz authored and Eduardo de Leon committed Mar 13, 2019
1 parent bf771be commit 9020aa9
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 127 deletions.
24 changes: 24 additions & 0 deletions mlflow/java/client/src/main/java/org/mlflow/api/proto/Service.java

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 19 additions & 16 deletions mlflow/projects/__init__.py
Expand Up @@ -14,26 +14,23 @@
import tempfile
import logging
import posixpath
import six
import docker

import mlflow.tracking as tracking
import mlflow.tracking.fluent as fluent
from mlflow.projects.submitted_run import LocalSubmittedRun, SubmittedRun
from mlflow.projects import _project_spec
from mlflow.exceptions import ExecutionException, MlflowException
from mlflow.entities import RunStatus, SourceType, Param
from mlflow.entities import RunStatus, SourceType
from mlflow.tracking.fluent import _get_experiment_id
from mlflow.tracking.context import _get_git_commit

import mlflow.projects.databricks
from mlflow.utils import process
from mlflow.utils.mlflow_tags import MLFLOW_GIT_REPO_URL, MLFLOW_GIT_BRANCH, \
LEGACY_MLFLOW_GIT_REPO_URL, LEGACY_MLFLOW_GIT_BRANCH_NAME
from mlflow.utils.mlflow_tags import MLFLOW_PROJECT_ENV
from mlflow.utils.mlflow_tags import MLFLOW_DOCKER_IMAGE_NAME, MLFLOW_DOCKER_IMAGE_ID
from mlflow.utils.mlflow_tags import MLFLOW_PROJECT_ENV, MLFLOW_DOCKER_IMAGE_NAME, \
MLFLOW_DOCKER_IMAGE_ID, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE, MLFLOW_GIT_COMMIT, \
MLFLOW_GIT_REPO_URL, MLFLOW_GIT_BRANCH, LEGACY_MLFLOW_GIT_REPO_URL, \
LEGACY_MLFLOW_GIT_BRANCH_NAME, MLFLOW_PROJECT_ENTRY_POINT, MLFLOW_PARENT_RUN_ID
from mlflow.utils import databricks_utils, file_utils
from mlflow.utils.logging_utils import eprint
import docker

# TODO: this should be restricted to just Git repos and not S3 and stuff like that
_GIT_URI_REGEX = re.compile(r"^[^/]*:")
Expand Down Expand Up @@ -541,18 +538,24 @@ def _create_run(uri, experiment_id, work_dir, entry_point):
source_name = tracking.utils._get_git_url_if_present(_expand_uri(uri))
else:
source_name = _expand_uri(uri)
source_version = _get_git_commit(work_dir)
existing_run = fluent.active_run()
if existing_run:
parent_run_id = existing_run.info.run_uuid
else:
parent_run_id = None
active_run = tracking.MlflowClient().create_run(
experiment_id=experiment_id,
source_name=source_name,
source_version=_get_git_commit(work_dir),
entry_point_name=entry_point,
source_type=SourceType.PROJECT,
parent_run_id=parent_run_id)

tags = {
MLFLOW_SOURCE_NAME: source_name,
MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.PROJECT),
MLFLOW_PROJECT_ENTRY_POINT: entry_point
}
if source_version is not None:
tags[MLFLOW_GIT_COMMIT] = source_version
if parent_run_id is not None:
tags[MLFLOW_PARENT_RUN_ID] = parent_run_id

active_run = tracking.MlflowClient().create_run(experiment_id=experiment_id, tags=tags)
return active_run


Expand Down
2 changes: 2 additions & 0 deletions mlflow/protos/service.proto
Expand Up @@ -584,6 +584,8 @@ message CreateRun {
repeated RunTag tags = 9;

// ID of the parent run which started this run.
// This field is deprecated and will be removed in MLflow 1.0. Use the ``mlflow.parentRunId`` run
// tag instead.
optional string parent_run_id = 10;

message Response {
Expand Down
31 changes: 24 additions & 7 deletions mlflow/tracking/client.py
Expand Up @@ -13,6 +13,8 @@
_validate_experiment_name, _validate_metric
from mlflow.entities import Param, Metric, RunStatus, RunTag, ViewType, SourceType
from mlflow.store.artifact_repository_registry import get_artifact_repository
from mlflow.utils.mlflow_tags import MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE, MLFLOW_PARENT_RUN_ID, \
MLFLOW_GIT_COMMIT, MLFLOW_PROJECT_ENTRY_POINT

_DEFAULT_USER_ID = "unknown"

Expand All @@ -36,9 +38,7 @@ def get_run(self, run_id):
_validate_run_id(run_id)
return self.store.get_run(run_id)

def create_run(self, experiment_id, user_id=None, run_name=None, source_type=None,
source_name=None, entry_point_name=None, start_time=None,
source_version=None, tags=None, parent_run_id=None):
def create_run(self, experiment_id, user_id=None, run_name=None, start_time=None, tags=None):
"""
Create a :py:class:`mlflow.entities.Run` object that can be associated with
metrics, parameters, artifacts, etc.
Expand All @@ -52,18 +52,35 @@ def create_run(self, experiment_id, user_id=None, run_name=None, source_type=Non
:py:class:`mlflow.entities.RunTag` objects.
:return: :py:class:`mlflow.entities.Run` that was created.
"""

tags = tags if tags else {}

# Extract run attributes from tags
# This logic is temporary; by the 1.0 release, this information will only be stored in tags
# and will not be available as attributes of the run
parent_run_id = tags.get(MLFLOW_PARENT_RUN_ID)
source_name = tags.get(MLFLOW_SOURCE_NAME, "Python Application")
source_version = tags.get(MLFLOW_GIT_COMMIT)
entry_point_name = tags.get(MLFLOW_PROJECT_ENTRY_POINT)

source_type_string = tags.get(MLFLOW_SOURCE_TYPE)
if source_type_string is None:
source_type = SourceType.LOCAL
else:
source_type = SourceType.from_string(source_type_string)

return self.store.create_run(
experiment_id=experiment_id,
user_id=user_id if user_id is not None else _get_user_id(),
run_name=run_name,
source_type=source_type if source_type is not None else SourceType.LOCAL,
source_name=source_name if source_name is not None else "Python Application",
entry_point_name=entry_point_name,
start_time=start_time or int(time.time() * 1000),
source_version=source_version,
tags=[RunTag(key, value) for (key, value) in iteritems(tags)],
# The below arguments remain set for backwards compatability:
parent_run_id=parent_run_id,
source_type=source_type,
source_name=source_name,
entry_point_name=entry_point_name,
source_version=source_version
)

def list_run_infos(self, experiment_id, run_view_type=ViewType.ACTIVE_ONLY):
Expand Down
15 changes: 4 additions & 11 deletions mlflow/tracking/fluent.py
Expand Up @@ -20,7 +20,7 @@
from mlflow.utils import env
from mlflow.utils.databricks_utils import is_in_databricks_notebook, get_notebook_id
from mlflow.utils.mlflow_tags import MLFLOW_GIT_COMMIT, MLFLOW_SOURCE_TYPE, MLFLOW_SOURCE_NAME, \
MLFLOW_PROJECT_ENTRY_POINT
MLFLOW_PROJECT_ENTRY_POINT, MLFLOW_PARENT_RUN_ID
from mlflow.utils.validation import _validate_run_id

_EXPERIMENT_ID_ENV_VAR = "MLFLOW_EXPERIMENT_ID"
Expand Down Expand Up @@ -123,6 +123,8 @@ def start_run(run_uuid=None, experiment_id=None, source_name=None, source_versio
exp_id_for_run = experiment_id if experiment_id is not None else _get_experiment_id()

user_specified_tags = {}
if parent_run_id is not None:
user_specified_tags[MLFLOW_PARENT_RUN_ID] = parent_run_id
if source_name is not None:
user_specified_tags[MLFLOW_SOURCE_NAME] = source_name
if source_type is not None:
Expand All @@ -134,19 +136,10 @@ def start_run(run_uuid=None, experiment_id=None, source_name=None, source_versio

tags = context.resolve_tags(user_specified_tags)

# Polling resolved tags for run meta data : source_name, source_version,
# entry_point_name, and source_type which is store in RunInfo for backward compatibility.
# TODO: Remove all 4 of the following annotated backward compatibility fixes with API
# changes to create_run.
active_run_obj = MlflowClient().create_run(
experiment_id=exp_id_for_run,
run_name=run_name,
source_name=tags.get(MLFLOW_SOURCE_NAME), # TODO: for backward compatibility. Remove.
source_version=tags.get(MLFLOW_GIT_COMMIT), # TODO: for backward compatibility. Remove.
entry_point_name=tags.get(MLFLOW_PROJECT_ENTRY_POINT), # TODO: remove
source_type=SourceType.from_string(tags.get(MLFLOW_SOURCE_TYPE)), # TODO: Remove
tags=tags,
parent_run_id=parent_run_id
tags=tags
)

_active_run_stack.append(ActiveRun(active_run_obj))
Expand Down
61 changes: 37 additions & 24 deletions tests/projects/test_projects.py
Expand Up @@ -8,11 +8,12 @@
import pytest

import mlflow
from mlflow.entities import RunStatus, ViewType
from mlflow.entities import RunStatus, ViewType, SourceType
from mlflow.exceptions import ExecutionException
from mlflow.utils import env
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID, MLFLOW_GIT_BRANCH, MLFLOW_GIT_REPO_URL, \
LEGACY_MLFLOW_GIT_BRANCH_NAME, LEGACY_MLFLOW_GIT_REPO_URL
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE, \
MLFLOW_GIT_BRANCH, MLFLOW_GIT_REPO_URL, LEGACY_MLFLOW_GIT_BRANCH_NAME, \
LEGACY_MLFLOW_GIT_REPO_URL, MLFLOW_PROJECT_ENTRY_POINT

from tests.projects.utils import TEST_PROJECT_DIR, TEST_PROJECT_NAME, GIT_PROJECT_URI, \
validate_exit_status, assert_dirs_equal
Expand All @@ -30,7 +31,7 @@ def _get_version_local_git_repo(local_git_repo):
return repo.git.rev_parse("HEAD")


@pytest.fixture()
@pytest.fixture
def local_git_repo(tmpdir):
local_git = tmpdir.join('git_repo').strpath
repo = git.Repo.init(local_git)
Expand All @@ -41,12 +42,12 @@ def local_git_repo(tmpdir):
yield os.path.abspath(local_git)


@pytest.fixture()
@pytest.fixture
def local_git_repo_uri(local_git_repo):
return "file://%s" % local_git_repo


@pytest.fixture()
@pytest.fixture
def zipped_repo(tmpdir):
import zipfile
zip_name = tmpdir.join('%s.zip' % TEST_PROJECT_NAME).strpath
Expand Down Expand Up @@ -204,18 +205,23 @@ def test_run_local_git_repo(local_git_repo,
store_run_uuid = run_infos[0].run_uuid
assert run_uuid == store_run_uuid
run = mlflow_service.get_run(run_uuid)
expected_params = {"use_start_run": use_start_run}

assert run.info.status == RunStatus.FINISHED
assert len(run.data.params) == len(expected_params)
for param in run.data.params:
assert param.value == expected_params[param.key]

expected_params = {"use_start_run": use_start_run}
params = {param.key: param.value for param in run.data.params}
assert params == expected_params

expected_metrics = {"some_key": 3}
assert len(run.data.metrics) == len(expected_metrics)
for metric in run.data.metrics:
assert metric.value == expected_metrics[metric.key]
# Validate the branch name tag is logged
metrics = {metric.key: metric.value for metric in run.data.metrics}
assert metrics == expected_metrics

tags = {tag.key: tag.value for tag in run.data.tags}
assert "file:" in tags[MLFLOW_SOURCE_NAME]
assert tags[MLFLOW_SOURCE_TYPE] == SourceType.to_string(SourceType.PROJECT)
assert tags[MLFLOW_PROJECT_ENTRY_POINT] == "test_tracking"

if version == "master":
tags = {tag.key: tag.value for tag in run.data.tags}
assert tags[MLFLOW_GIT_BRANCH] == "master"
assert tags[MLFLOW_GIT_REPO_URL] == local_git_repo_uri
assert tags[LEGACY_MLFLOW_GIT_BRANCH_NAME] == "master"
Expand Down Expand Up @@ -248,20 +254,27 @@ def test_run(tmpdir, tracking_uri_mock, use_start_run): # pylint: disable=unuse
# Validate run contents in the FileStore
run_uuid = submitted_run.run_id
mlflow_service = mlflow.tracking.MlflowClient()

run_infos = mlflow_service.list_run_infos(experiment_id=0, run_view_type=ViewType.ACTIVE_ONLY)
assert len(run_infos) == 1
store_run_uuid = run_infos[0].run_uuid
assert run_uuid == store_run_uuid
run = mlflow_service.get_run(run_uuid)
expected_params = {"use_start_run": use_start_run}

assert run.info.status == RunStatus.FINISHED
assert len(run.data.params) == len(expected_params)
for param in run.data.params:
assert param.value == expected_params[param.key]

expected_params = {"use_start_run": use_start_run}
params = {param.key: param.value for param in run.data.params}
assert params == expected_params

expected_metrics = {"some_key": 3}
assert len(run.data.metrics) == len(expected_metrics)
for metric in run.data.metrics:
assert metric.value == expected_metrics[metric.key]
metrics = {metric.key: metric.value for metric in run.data.metrics}
assert metrics == expected_metrics

tags = {tag.key: tag.value for tag in run.data.tags}
assert "file:" in tags[MLFLOW_SOURCE_NAME]
assert tags[MLFLOW_SOURCE_TYPE] == SourceType.to_string(SourceType.PROJECT)
assert tags[MLFLOW_PROJECT_ENTRY_POINT] == "test_tracking"


def test_run_with_parent(tmpdir, tracking_uri_mock): # pylint: disable=unused-argument
Expand All @@ -276,8 +289,8 @@ def test_run_with_parent(tmpdir, tracking_uri_mock): # pylint: disable=unused-a
validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
run_uuid = submitted_run.run_id
run = mlflow.tracking.MlflowClient().get_run(run_uuid)
parent_run_id_tag = [tag.value for tag in run.data.tags if tag.key == MLFLOW_PARENT_RUN_ID]
assert parent_run_id_tag == [parent_run_id]
tags = {tag.key: tag.value for tag in run.data.tags}
assert tags[MLFLOW_PARENT_RUN_ID] == parent_run_id


def test_run_async(tracking_uri_mock): # pylint: disable=unused-argument
Expand Down

0 comments on commit 9020aa9

Please sign in to comment.