Skip to content

Commit

Permalink
refactor: renaming tracking service.MLflowService to client.MLflowClient (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
mparkhe authored and aarondav committed Sep 10, 2018
1 parent 179c47d commit d04b8fb
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 66 deletions.
12 changes: 6 additions & 6 deletions docs/source/tracking.rst
Expand Up @@ -166,19 +166,19 @@ Managing Experiments and Runs with the Tracking Service API
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

MLflow provides a more detailed Tracking Service API for managing experiments and runs directly,
which is available in the :py:mod:`mlflow.tracking` module.
which is available through the client SDK in the :py:mod:`mlflow.tracking` module.
This makes it possible to query data about past runs, log additional information about them, create experiments and more.

Example usage:

.. code:: python
from mlflow.tracking import get_service
service = get_service()
from mlflow.tracking import MlflowClient
client = MlflowClient()
experiments = client.list_experiments() # returns a list of mlflow.entities.Experiment
run = service.create_run(experiments[0].experiment_id) # returns mlflow.entities.Run
service.log_param(run.info.run_uuid, "hello", "world")
service.set_terminated(run.info.run_uuid)
run = client.create_run(experiments[0].experiment_id) # returns mlflow.entities.Run
client.log_param(run.info.run_uuid, "hello", "world")
client.set_terminated(run.info.run_uuid)
.. _tracking_ui:

Expand Down
10 changes: 5 additions & 5 deletions examples/multistep_workflow/main.py
Expand Up @@ -28,13 +28,13 @@ def _already_ran(entry_point_name, parameters, source_version, experiment_id=Non
successfully and have at least the parameters provided.
"""
experiment_id = experiment_id if experiment_id is not None else _get_experiment_id()
service = mlflow.tracking.get_service()
all_run_infos = reversed(service.list_run_infos(experiment_id))
client = mlflow.tracking.MlflowClient()
all_run_infos = reversed(client.list_run_infos(experiment_id))
for run_info in all_run_infos:
if run_info.entry_point_name != entry_point_name:
continue

full_run = service.get_run(run_info.run_uuid)
full_run = client.get_run(run_info.run_uuid)
run_params = _get_params(full_run)
match_failed = False
for param_key, param_value in six.iteritems(parameters):
Expand All @@ -53,7 +53,7 @@ def _already_ran(entry_point_name, parameters, source_version, experiment_id=Non
eprint(("Run matched, but has a different source version, so skipping "
"(found=%s, expected=%s)") % (run_info.source_version, source_version))
continue
return service.get_run(run_info.run_uuid)
return client.get_run(run_info.run_uuid)
return None


Expand All @@ -67,7 +67,7 @@ def _get_or_run(entrypoint, parameters, source_version, use_cache=True):
return existing_run
print("Launching new run for entrypoint=%s and parameters=%s" % (entrypoint, parameters))
submitted_run = mlflow.run(".", entrypoint, parameters=parameters)
return mlflow.tracking.get_service().get_run(submitted_run.run_id)
return mlflow.tracking.MlflowClient().get_run(submitted_run.run_id)


@click.command()
Expand Down
2 changes: 1 addition & 1 deletion examples/remote_store/remote_server.py
Expand Up @@ -19,7 +19,7 @@
log_metric("random_int", random.randint(0, 100))
run_uuid = active_run().info.run_uuid
# Get run metadata & data from the tracking server
service = mlflow.tracking.get_service()
service = mlflow.tracking.MlflowClient()
run = service.get_run(run_uuid)
print("Metadata & data for run with UUID %s: %s" % (run_uuid, run))
local_dir = tempfile.mkdtemp()
Expand Down
12 changes: 6 additions & 6 deletions mlflow/projects/__init__.py
Expand Up @@ -44,7 +44,7 @@ def _run(uri, entry_point="main", version=None, parameters=None, experiment_id=N
project = _project_spec.load_project(work_dir)
project.get_entry_point(entry_point)._validate_parameters(parameters)
if run_id:
active_run = tracking.get_service().get_run(run_id)
active_run = tracking.MlflowClient().get_run(run_id)
else:
active_run = _create_run(uri, exp_id, work_dir, entry_point)

Expand All @@ -53,7 +53,7 @@ def _run(uri, entry_point="main", version=None, parameters=None, experiment_id=N
entry_point_obj = project.get_entry_point(entry_point)
final_params, extra_params = entry_point_obj.compute_parameters(parameters, storage_dir=None)
for key, value in (list(final_params.items()) + list(extra_params.items())):
tracking.get_service().log_param(active_run.info.run_uuid, key, value)
tracking.MlflowClient().log_param(active_run.info.run_uuid, key, value)

if mode == "databricks":
from mlflow.projects.databricks import run_databricks
Expand Down Expand Up @@ -145,7 +145,7 @@ def _wait_for(submitted_run_obj):
# Note: there's a small chance we fail to report the run's status to the tracking server if
# we're interrupted before we reach the try block below
try:
active_run = tracking.get_service().get_run(run_id) if run_id is not None else None
active_run = tracking.MlflowClient().get_run(run_id) if run_id is not None else None
if submitted_run_obj.wait():
eprint("=== Run (ID '%s') succeeded ===" % run_id)
_maybe_set_run_terminated(active_run, "FINISHED")
Expand Down Expand Up @@ -303,10 +303,10 @@ def _maybe_set_run_terminated(active_run, status):
if active_run is None:
return
run_id = active_run.info.run_uuid
cur_status = tracking.get_service().get_run(run_id).info.status
cur_status = tracking.MlflowClient().get_run(run_id).info.status
if RunStatus.is_terminated(cur_status):
return
tracking.get_service().set_terminated(run_id, status)
tracking.MlflowClient().set_terminated(run_id, status)


def _get_entry_point_command(project, entry_point, parameters, conda_env_name, storage_dir):
Expand Down Expand Up @@ -386,7 +386,7 @@ def _create_run(uri, experiment_id, work_dir, entry_point):
source_name = tracking.utils._get_git_url_if_present(_expand_uri(uri))
else:
source_name = _expand_uri(uri)
active_run = tracking.get_service().create_run(
active_run = tracking.MlflowClient().create_run(
experiment_id=experiment_id,
source_name=source_name,
source_version=_get_git_commit(work_dir),
Expand Down
16 changes: 8 additions & 8 deletions mlflow/projects/databricks.py
Expand Up @@ -310,18 +310,18 @@ def _print_description_and_log_tags(self):
jobs_page_url = run_info["run_page_url"]
eprint("=== Check the run's status at %s ===" % jobs_page_url)
host_creds = databricks_utils.get_databricks_host_creds(self._job_runner.databricks_profile)
tracking.get_service().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_RUN_URL, jobs_page_url)
tracking.get_service().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, self._databricks_run_id)
tracking.get_service().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_WEBAPP_URL, host_creds.host)
tracking.MlflowClient().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_RUN_URL, jobs_page_url)
tracking.MlflowClient().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, self._databricks_run_id)
tracking.MlflowClient().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_WEBAPP_URL, host_creds.host)
job_id = run_info.get('job_id')
# In some releases of Databricks we do not return the job ID. We start including it in DB
# releases 2.80 and above.
if job_id is not None:
tracking.get_service().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_SHELL_JOB_ID, job_id)
tracking.MlflowClient().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_SHELL_JOB_ID, job_id)

@property
def run_id(self):
Expand Down
5 changes: 2 additions & 3 deletions mlflow/tracking/__init__.py
Expand Up @@ -5,14 +5,13 @@
For a higher level API for managing an "active run", use the :py:mod:`mlflow` module.
"""

from mlflow.tracking.service import MLflowService, get_service
from mlflow.tracking.client import MlflowClient
from mlflow.tracking.utils import set_tracking_uri, get_tracking_uri, _get_store, \
_TRACKING_URI_ENV_VAR
from mlflow.tracking.fluent import _EXPERIMENT_ID_ENV_VAR, _RUN_ID_ENV_VAR

__all__ = [
"MLflowService",
"get_service",
"MlflowClient",
"get_tracking_uri",
"set_tracking_uri",
"_get_store",
Expand Down
26 changes: 10 additions & 16 deletions mlflow/tracking/service.py → mlflow/tracking/client.py
Expand Up @@ -17,12 +17,19 @@
_DEFAULT_USER_ID = "unknown"


class MLflowService(object):
class MlflowClient(object):
"""Client of an MLflow Tracking Server that creates and manages experiments and runs.
"""

def __init__(self, store):
self.store = store
def __init__(self, tracking_uri=None):
"""
:param tracking_uri: Address of local or remote tracking server. If not provided, defaults
to the service set by ``mlflow.tracking.set_tracking_uri``. See
`Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
for more info.
"""
self.tracking_uri = tracking_uri
self.store = _get_store(tracking_uri)

def get_run(self, run_id):
""":return: :py:class:`mlflow.entities.Run` associated with the run ID."""
Expand Down Expand Up @@ -197,19 +204,6 @@ def set_terminated(self, run_id, status=None, end_time=None):
end_time=end_time)


def get_service(tracking_uri=None):
"""
Get the tracking service.
:param tracking_uri: Address of local or remote tracking server. If not provided,
this defaults to the service set by ``mlflow.tracking.set_tracking_uri``. See
`Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_ for more info.
:return: :py:class:`mlflow.tracking.MLflowService`
"""
store = _get_store(tracking_uri)
return MLflowService(store)


def _get_user_id():
"""Get the ID of the user for the current run."""
try:
Expand Down
22 changes: 11 additions & 11 deletions mlflow/tracking/fluent.py
Expand Up @@ -19,7 +19,7 @@
MLFLOW_DATABRICKS_NOTEBOOK_PATH, \
MLFLOW_DATABRICKS_NOTEBOOK_ID
from mlflow.utils.validation import _validate_run_id
from mlflow.tracking.service import get_service
from mlflow.tracking.client import MlflowClient


_EXPERIMENT_ID_ENV_VAR = "MLFLOW_EXPERIMENT_ID"
Expand Down Expand Up @@ -77,7 +77,7 @@ def start_run(run_uuid=None, experiment_id=None, source_name=None, source_versio
existing_run_uuid = run_uuid or os.environ.get(_RUN_ID_ENV_VAR, None)
if existing_run_uuid:
_validate_run_id(existing_run_uuid)
active_run_obj = get_service().get_run(existing_run_uuid)
active_run_obj = MlflowClient().get_run(existing_run_uuid)
else:
exp_id_for_run = experiment_id or _get_experiment_id()
if is_in_databricks_notebook():
Expand All @@ -91,7 +91,7 @@ def start_run(run_uuid=None, experiment_id=None, source_name=None, source_versio
databricks_tags[MLFLOW_DATABRICKS_NOTEBOOK_PATH] = notebook_path
if webapp_url is not None:
databricks_tags[MLFLOW_DATABRICKS_WEBAPP_URL] = webapp_url
active_run_obj = get_service().create_run(
active_run_obj = MlflowClient().create_run(
experiment_id=exp_id_for_run,
run_name=run_name,
source_name=notebook_path,
Expand All @@ -100,7 +100,7 @@ def start_run(run_uuid=None, experiment_id=None, source_name=None, source_versio
source_type=SourceType.NOTEBOOK,
tags=databricks_tags)
else:
active_run_obj = get_service().create_run(
active_run_obj = MlflowClient().create_run(
experiment_id=exp_id_for_run,
run_name=run_name,
source_name=source_name or _get_source_name(),
Expand All @@ -115,7 +115,7 @@ def end_run(status="FINISHED"):
"""End an active MLflow run (if there is one)."""
global _active_run
if _active_run:
get_service().set_terminated(_active_run.info.run_uuid, status)
MlflowClient().set_terminated(_active_run.info.run_uuid, status)
# Clear out the global existing run environment variable as well.
env.unset_variable(_RUN_ID_ENV_VAR)
_active_run = None
Expand All @@ -137,7 +137,7 @@ def log_param(key, value):
:param value: Parameter value (string, but will be string-ified if not)
"""
run_id = _get_or_start_run().info.run_uuid
get_service().log_param(run_id, key, value)
MlflowClient().log_param(run_id, key, value)


def set_tag(key, value):
Expand All @@ -148,7 +148,7 @@ def set_tag(key, value):
:param value: Tag value (string, but will be string-ified if not)
"""
run_id = _get_or_start_run().info.run_uuid
get_service().set_tag(run_id, key, value)
MlflowClient().set_tag(run_id, key, value)


def log_metric(key, value):
Expand All @@ -163,7 +163,7 @@ def log_metric(key, value):
key, value), file=sys.stderr)
return
run_id = _get_or_start_run().info.run_uuid
get_service().log_metric(run_id, key, value, int(time.time()))
MlflowClient().log_metric(run_id, key, value, int(time.time()))


def log_artifact(local_path, artifact_path=None):
Expand All @@ -174,7 +174,7 @@ def log_artifact(local_path, artifact_path=None):
:param artifact_path: If provided, the directory in ``artifact_uri`` to write to.
"""
run_id = _get_or_start_run().info.run_uuid
get_service().log_artifact(run_id, local_path, artifact_path)
MlflowClient().log_artifact(run_id, local_path, artifact_path)


def log_artifacts(local_dir, artifact_path=None):
Expand All @@ -185,7 +185,7 @@ def log_artifacts(local_dir, artifact_path=None):
:param artifact_path: If provided, the directory in ``artifact_uri`` to write to.
"""
run_id = _get_or_start_run().info.run_uuid
get_service().log_artifacts(run_id, local_dir, artifact_path)
MlflowClient().log_artifacts(run_id, local_dir, artifact_path)


def create_experiment(name, artifact_location=None):
Expand All @@ -197,7 +197,7 @@ def create_experiment(name, artifact_location=None):
If not provided, the server picks an appropriate default.
:return: Integer ID of the created experiment.
"""
return get_service().create_experiment(name, artifact_location)
return MlflowClient().create_experiment(name, artifact_location)


def get_artifact_uri():
Expand Down
6 changes: 3 additions & 3 deletions tests/projects/test_databricks.py
Expand Up @@ -12,7 +12,7 @@
from mlflow.projects.databricks import DatabricksJobRunner
from mlflow.entities import RunStatus
from mlflow.projects import databricks, ExecutionException
from mlflow.tracking import get_service
from mlflow.tracking import MlflowClient
from mlflow.utils import file_utils
from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_RUN_URL, \
MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, \
Expand Down Expand Up @@ -93,8 +93,8 @@ def before_run_validations_mock(): # pylint: disable=unused-argument

@pytest.fixture()
def set_tag_mock():
with mock.patch("mlflow.projects.databricks.tracking.get_service") as m:
mlflow_service_mock = mock.Mock(wraps=get_service())
with mock.patch("mlflow.projects.databricks.tracking.MlflowClient") as m:
mlflow_service_mock = mock.Mock(wraps=MlflowClient())
m.return_value = mlflow_service_mock
yield mlflow_service_mock.set_tag

Expand Down
14 changes: 7 additions & 7 deletions tests/tracking/test_tracking.py
Expand Up @@ -52,10 +52,10 @@ def test_start_run_context_manager():
first_uuid = first_run.info.run_uuid
with first_run:
# Check that start_run() causes the run information to be persisted in the store
persisted_run = tracking.get_service().get_run(first_uuid)
persisted_run = tracking.MlflowClient().get_run(first_uuid)
assert persisted_run is not None
assert persisted_run.info == first_run.info
finished_run = tracking.get_service().get_run(first_uuid)
finished_run = tracking.MlflowClient().get_run(first_uuid)
assert finished_run.info.status == RunStatus.FINISHED
# Launch a separate run that fails, verify the run status is FAILED and the run UUID is
# different
Expand All @@ -64,7 +64,7 @@ def test_start_run_context_manager():
with pytest.raises(Exception):
with second_run:
raise Exception("Failing run!")
finished_run2 = tracking.get_service().get_run(second_run.info.run_uuid)
finished_run2 = tracking.MlflowClient().get_run(second_run.info.run_uuid)
assert finished_run2.info.status == RunStatus.FAILED
finally:
tracking.set_tracking_uri(None)
Expand All @@ -77,7 +77,7 @@ def test_start_and_end_run():
active_run = start_run()
mlflow.log_metric("name_1", 25)
end_run()
finished_run = tracking.get_service().get_run(active_run.info.run_uuid)
finished_run = tracking.MlflowClient().get_run(active_run.info.run_uuid)
# Validate metrics
assert len(finished_run.data.metrics) == 1
expected_pairs = {"name_1": 25}
Expand All @@ -97,7 +97,7 @@ def test_log_metric():
mlflow.log_metric("name_2", -3)
mlflow.log_metric("name_1", 30)
mlflow.log_metric("nested/nested/name", 40)
finished_run = tracking.get_service().get_run(run_uuid)
finished_run = tracking.MlflowClient().get_run(run_uuid)
# Validate metrics
assert len(finished_run.data.metrics) == 3
expected_pairs = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
Expand All @@ -114,7 +114,7 @@ def test_log_metric_validation():
run_uuid = active_run.info.run_uuid
with active_run:
mlflow.log_metric("name_1", "apple")
finished_run = tracking.get_service().get_run(run_uuid)
finished_run = tracking.MlflowClient().get_run(run_uuid)
assert len(finished_run.data.metrics) == 0
finally:
tracking.set_tracking_uri(None)
Expand All @@ -130,7 +130,7 @@ def test_log_param():
mlflow.log_param("name_2", "b")
mlflow.log_param("name_1", "c")
mlflow.log_param("nested/nested/name", 5)
finished_run = tracking.get_service().get_run(run_uuid)
finished_run = tracking.MlflowClient().get_run(run_uuid)
# Validate params
assert len(finished_run.data.params) == 3
expected_pairs = {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
Expand Down

0 comments on commit d04b8fb

Please sign in to comment.