[#297] Complete Python SDK #299

Merged
merged 4 commits on Aug 17, 2018
8 changes: 8 additions & 0 deletions docs/source/python_api/mlflow.entities.rst
@@ -0,0 +1,8 @@
mlflow.entities
Contributor:

We should probably also update docs/source/tracking.rst to describe the two APIs, right?

Contributor Author:

Documentation is important, and I would like to get that closely reviewed. Let me do that in a follow-up PR which I'll go ahead and draft.

===============

.. automodule:: mlflow.entities
:members:
:undoc-members:
:show-inheritance:
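
Picking up the reviewer's point about describing the two APIs: the sketch below is not part of this diff; it contrasts the fluent, module-level calls with the lower-level tracking service for the follow-up docs. The `get_service()` accessor and its `get_run()` method are taken from calls made elsewhere in this PR, and the exact return shapes are assumptions.

```python
# Illustrative only: the fluent API vs. the lower-level tracking service.
import mlflow
import mlflow.tracking

# Fluent API: operates on an implicit active run.
with mlflow.start_run() as run:
    mlflow.log_param("alpha", "0.5")
    mlflow.log_metric("score", 100)
    run_uuid = run.info.run_uuid  # assumes ActiveRun exposes the underlying Run's info

# Lower-level API: every call names the run explicitly by its UUID.
service = mlflow.tracking.get_service()
fetched = service.get_run(run_uuid)  # returns an mlflow.entities.Run
print(fetched.info.status)
```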

46 changes: 36 additions & 10 deletions mlflow/__init__.py
@@ -1,22 +1,48 @@
"""
Provides the MLflow fluent API, allowing management of an active MLflow run.
For example:

.. code:: python

import mlflow
mlflow.start_run()
mlflow.log_param("my", "param")
mlflow.log_metric("score", 100)
mlflow.end_run()

You can also use syntax like this:

.. code:: python

with mlflow.start_run() as run:
...

which will automatically terminate the run at the end of the block.
"""

import os

# pylint: disable=wrong-import-position
import mlflow.projects as projects # noqa
import mlflow.tracking as tracking # noqa
import mlflow.tracking.fluent

log_param = tracking.log_param
log_metric = tracking.log_metric
log_artifacts = tracking.log_artifacts
log_artifact = tracking.log_artifact
active_run = tracking.active_run
start_run = tracking.start_run
end_run = tracking.end_run
get_artifact_uri = tracking.get_artifact_uri
ActiveRun = mlflow.tracking.fluent.ActiveRun
log_param = mlflow.tracking.fluent.log_param
log_metric = mlflow.tracking.fluent.log_metric
log_artifacts = mlflow.tracking.fluent.log_artifacts
log_artifact = mlflow.tracking.fluent.log_artifact
active_run = mlflow.tracking.fluent.active_run
start_run = mlflow.tracking.fluent.start_run
end_run = mlflow.tracking.fluent.end_run
get_artifact_uri = mlflow.tracking.fluent.get_artifact_uri
set_tracking_uri = tracking.set_tracking_uri
get_tracking_uri = tracking.get_tracking_uri
create_experiment = tracking.create_experiment
create_experiment = mlflow.tracking.fluent.create_experiment


run = projects.run

__all__ = ["log_param", "log_metric", "log_artifacts", "log_artifact", "active_run",

__all__ = ["ActiveRun", "log_param", "log_metric", "log_artifacts", "log_artifact", "active_run",
"start_run", "end_run", "get_artifact_uri", "set_tracking_uri", "create_experiment"]
2 changes: 1 addition & 1 deletion mlflow/azureml/__init__.py
@@ -7,7 +7,7 @@
import mlflow
from mlflow import pyfunc
from mlflow.models import Model
from mlflow.tracking import _get_model_log_dir
from mlflow.tracking.utils import _get_model_log_dir
from mlflow.utils.logging_utils import eprint
from mlflow.utils.file_utils import TempDir
from mlflow.version import VERSION as mlflow_version
25 changes: 25 additions & 0 deletions mlflow/entities/__init__.py
@@ -0,0 +1,25 @@
"""All entities returned by the MLflow REST API."""

from mlflow.entities.experiment import Experiment
Contributor:

Could add a doc comment here too

Contributor Author:

Done

from mlflow.entities.file_info import FileInfo
from mlflow.entities.metric import Metric
from mlflow.entities.param import Param
from mlflow.entities.run import Run
from mlflow.entities.run_data import RunData
from mlflow.entities.run_info import RunInfo
from mlflow.entities.run_status import RunStatus
from mlflow.entities.run_tag import RunTag
from mlflow.entities.source_type import SourceType

__all__ = [
"Experiment",
"FileInfo",
"Metric",
"Param",
"Run",
"RunData",
"RunInfo",
"RunStatus",
"RunTag",
"SourceType",
]
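
A quick sketch (not in the diff) of the status and source-type enums now importable from the package root. Only `is_terminated` and `SourceType.PROJECT` appear in this PR, so the `RUNNING`/`FINISHED` member names are assumptions.

```python
from mlflow.entities import RunStatus, SourceType

# is_terminated is the same check _maybe_set_run_terminated uses further down.
assert not RunStatus.is_terminated(RunStatus.RUNNING)
assert RunStatus.is_terminated(RunStatus.FINISHED)
print(SourceType.PROJECT)  # source type recorded for project runs
```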
2 changes: 1 addition & 1 deletion mlflow/h2o.py
@@ -91,5 +91,5 @@ def load_model(path, run_id=None):
`h2o.init()`.
"""
if run_id is not None:
path = mlflow.tracking._get_model_log_dir(model_name=path, run_id=run_id)
path = mlflow.tracking.utils._get_model_log_dir(model_name=path, run_id=run_id)
return _load_model(os.path.join(path, "model.h2o"))
2 changes: 1 addition & 1 deletion mlflow/keras.py
@@ -64,5 +64,5 @@ def load_model(path, run_id=None):
Load a Keras model from a local file (if run_id is None) or a run.
"""
if run_id is not None:
path = mlflow.tracking._get_model_log_dir(model_name=path, run_id=run_id)
path = mlflow.tracking.utils._get_model_log_dir(model_name=path, run_id=run_id)
return _load_model(os.path.join(path, "model.h5"))
4 changes: 2 additions & 2 deletions mlflow/models/__init__.py
@@ -51,7 +51,7 @@ def log(cls, artifact_path, flavor, **kwargs):
"""
with TempDir() as tmp:
local_path = tmp.path("model")
run_id = mlflow.tracking._get_or_start_run().run_info.run_uuid
run_id = mlflow.tracking.fluent._get_or_start_run().info.run_uuid
mlflow_model = cls(artifact_path=artifact_path, run_id=run_id)
flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)
mlflow.tracking.log_artifacts(local_path, artifact_path)
mlflow.tracking.fluent.log_artifacts(local_path, artifact_path)
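
Since `Model.log()` now resolves the run through `mlflow.tracking.fluent` and uploads with `fluent.log_artifacts`, a flavor's `log_model` can stay a one-liner. The sketch below is hypothetical: `my_flavor` stands in for a real flavor module exposing `save_model(path, mlflow_model, **kwargs)`, and extra keyword arguments to `Model.log` are forwarded to it.

```python
from mlflow.models import Model
import my_flavor  # hypothetical flavor module with save_model(path, mlflow_model, trained_model=...)

def log_model(trained_model, artifact_path):
    # Model.log starts (or reuses) the active run, saves via the flavor into a
    # temp directory, then logs that directory as artifacts under artifact_path.
    Model.log(artifact_path=artifact_path, flavor=my_flavor, trained_model=trained_model)
```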
33 changes: 19 additions & 14 deletions mlflow/projects/__init__.py
@@ -13,10 +13,9 @@
from mlflow.projects.submitted_run import LocalSubmittedRun
from mlflow.projects import _project_spec
from mlflow.utils.exception import ExecutionException
from mlflow.entities.run_status import RunStatus
from mlflow.entities.source_type import SourceType
from mlflow.entities.param import Param
from mlflow.entities import RunStatus, SourceType, Param
import mlflow.tracking as tracking
from mlflow.tracking.fluent import _get_experiment_id, _get_git_commit


from mlflow.utils import process
@@ -36,7 +35,7 @@ def _run(uri, entry_point="main", version=None, parameters=None, experiment_id=N
Helper that delegates to the project-running method corresponding to the passed-in mode.
Returns a ``SubmittedRun`` corresponding to the project run.
"""
exp_id = experiment_id or tracking._get_experiment_id()
exp_id = experiment_id or _get_experiment_id()
parameters = parameters or {}
if mode == "databricks":
from mlflow.projects.databricks import run_databricks
@@ -53,7 +52,7 @@ def _run(uri, entry_point="main", version=None, parameters=None, experiment_id=N
# failures due to multiple concurrent attempts to create the same conda env.
conda_env_name = _get_or_create_conda_env(project.conda_env_path) if use_conda else None
if run_id:
active_run = tracking._get_existing_run(run_id)
active_run = tracking.get_service().get_run(run_id)
else:
active_run = _create_run(uri, exp_id, work_dir, entry_point, parameters)
# In blocking mode, run the entry point command in blocking fashion, sending status updates
@@ -62,11 +61,11 @@ def _run(uri, entry_point="main", version=None, parameters=None, experiment_id=N
if block:
command = _get_entry_point_command(
project, entry_point, parameters, conda_env_name, storage_dir)
return _run_entry_point(command, work_dir, exp_id, run_id=active_run.run_info.run_uuid)
return _run_entry_point(command, work_dir, exp_id, run_id=active_run.info.run_uuid)
# Otherwise, invoke `mlflow run` in a subprocess
return _invoke_mlflow_run_subprocess(
work_dir=work_dir, entry_point=entry_point, parameters=parameters, experiment_id=exp_id,
use_conda=use_conda, storage_dir=storage_dir, run_id=active_run.run_info.run_uuid)
use_conda=use_conda, storage_dir=storage_dir, run_id=active_run.info.run_uuid)
supported_modes = ["local", "databricks"]
raise ExecutionException("Got unsupported execution mode %s. Supported "
"values: %s" % (mode, supported_modes))
@@ -131,7 +130,7 @@ def _wait_for(submitted_run_obj):
# Note: there's a small chance we fail to report the run's status to the tracking server if
# we're interrupted before we reach the try block below
try:
active_run = tracking._get_existing_run(run_id) if run_id is not None else None
active_run = tracking.get_service().get_run(run_id) if run_id is not None else None
if submitted_run_obj.wait():
eprint("=== Run (ID '%s') succeeded ===" % run_id)
_maybe_set_run_terminated(active_run, "FINISHED")
@@ -286,8 +285,13 @@ def _maybe_set_run_terminated(active_run, status):
If the passed-in active run is defined and still running (i.e. hasn't already been terminated
within user code), mark it as terminated with the passed-in status.
"""
if active_run and not RunStatus.is_terminated(active_run.get_run().info.status):
active_run.set_terminated(status)
if active_run is None:
return
run_id = active_run.info.run_uuid
cur_status = tracking.get_service().get_run(run_id).info.status
if RunStatus.is_terminated(cur_status):
return
tracking.get_service().set_terminated(run_id, status)


def _get_entry_point_command(project, entry_point, parameters, conda_env_name, storage_dir):
@@ -364,17 +368,18 @@ def _create_run(uri, experiment_id, work_dir, entry_point, parameters):
used to report additional data about the run (metrics/params) to the tracking server.
"""
if _is_local_uri(uri):
source_name = tracking._get_git_url_if_present(_expand_uri(uri))
source_name = tracking.utils._get_git_url_if_present(_expand_uri(uri))
else:
source_name = _expand_uri(uri)
active_run = tracking._create_run(
active_run = tracking.get_service().create_run(
experiment_id=experiment_id,
source_name=source_name,
source_version=tracking._get_git_commit(work_dir), entry_point_name=entry_point,
source_version=_get_git_commit(work_dir),
entry_point_name=entry_point,
source_type=SourceType.PROJECT)
if parameters is not None:
for key, value in parameters.items():
active_run.log_param(Param(key, value))
tracking.get_service().log_param(active_run.info.run_uuid, key, value)
return active_run


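The refactored helpers above are driven by the public `mlflow.projects.run` (re-exported as `mlflow.run` in `mlflow/__init__.py`). A hedged sketch of a local launch follows; the keyword names are assumed to mirror `_run`'s signature, the project URI is hypothetical, and `wait()` is the same SubmittedRun call `_wait_for` uses.

```python
import mlflow.projects

submitted = mlflow.projects.run(
    uri="examples/sklearn_elasticnet_wine",  # hypothetical local project directory
    entry_point="main",
    parameters={"alpha": 0.5},
    mode="local",                            # or "databricks", per supported_modes above
    block=False,
)
submitted.wait()  # blocks until the entry point process finishes
```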
12 changes: 6 additions & 6 deletions mlflow/projects/databricks.py
@@ -8,8 +8,7 @@

from six.moves import shlex_quote, urllib

from mlflow.entities.run_status import RunStatus
from mlflow.entities.source_type import SourceType
from mlflow.entities import RunStatus, SourceType


from mlflow.projects import _fetch_project, _expand_uri, _project_spec
@@ -18,6 +17,7 @@
from mlflow.utils.exception import ExecutionException
from mlflow.utils.logging_utils import eprint
from mlflow import tracking
from mlflow.tracking.fluent import _get_git_commit
from mlflow.version import VERSION

# Base directory within driver container for storing files related to MLflow
@@ -224,7 +224,7 @@ def _before_run_validations(tracking_uri, cluster_spec):
if cluster_spec is None:
raise ExecutionException("Cluster spec must be provided when launching MLflow project runs "
"on Databricks.")
if tracking.is_local_uri(tracking_uri):
if tracking.utils._is_local_uri(tracking_uri):
raise ExecutionException(
"When running on Databricks, the MLflow tracking URI must be set to a remote URI "
"accessible to both the current client and code running on Databricks. Got local "
@@ -244,15 +244,15 @@ def run_databricks(uri, entry_point, version, parameters, experiment_id, cluster
project = _project_spec.load_project(work_dir)
project.get_entry_point(entry_point)._validate_parameters(parameters)
dbfs_fuse_uri = _upload_project_to_dbfs(work_dir, experiment_id)
remote_run = tracking._create_run(
remote_run = tracking.get_service().create_run(
experiment_id=experiment_id, source_name=_expand_uri(uri),
source_version=tracking._get_git_commit(work_dir), entry_point_name=entry_point,
source_version=_get_git_commit(work_dir), entry_point_name=entry_point,
source_type=SourceType.PROJECT)
env_vars = {
tracking._TRACKING_URI_ENV_VAR: tracking_uri,
tracking._EXPERIMENT_ID_ENV_VAR: experiment_id,
}
run_id = remote_run.run_info.run_uuid
run_id = remote_run.info.run_uuid
eprint("=== Running entry point %s of project %s on Databricks. ===" % (entry_point, uri))
# Launch run on Databricks
with open(cluster_spec, 'r') as handle:
2 changes: 1 addition & 1 deletion mlflow/projects/submitted_run.py
@@ -3,7 +3,7 @@
import os
import signal

from mlflow.entities.run_status import RunStatus
from mlflow.entities import RunStatus
from mlflow.utils.logging_utils import eprint


11 changes: 6 additions & 5 deletions mlflow/pyfunc/__init__.py
@@ -79,6 +79,7 @@
import sys
import pandas

from mlflow.tracking.fluent import active_run, log_artifacts
from mlflow import tracking
from mlflow.models import Model
from mlflow.utils import PYTHON_VERSION, get_major_minor_py_version
@@ -124,7 +125,7 @@ def add_to_model(model, loader_module, data=None, code=None, env=None):
def _load_model_conf(path, run_id=None):
"""Load a model configuration stored in Python function format."""
if run_id:
path = tracking._get_model_log_dir(path, run_id)
path = tracking.utils._get_model_log_dir(path, run_id)
conf_path = os.path.join(path, "MLmodel")
model = Model.load(conf_path)
if FLAVOR_NAME not in model.flavors:
@@ -151,7 +152,7 @@ def load_pyfunc(path, run_id=None, suppress_warnings=False):
will be emitted.
"""
if run_id:
path = tracking._get_model_log_dir(path, run_id)
path = tracking.utils._get_model_log_dir(path, run_id)
conf = _load_model_conf(path)
model_py_version = conf.get(PY_VERSION)
if not suppress_warnings:
@@ -212,7 +213,7 @@ def spark_udf(spark, path, run_id=None, result_type="double"):
from pyspark.sql.functions import pandas_udf

if run_id:
path = tracking._get_model_log_dir(path, run_id)
path = tracking.utils._get_model_log_dir(path, run_id)

archive_path = SparkModelCache.add_local_model(spark, path)

@@ -289,13 +290,13 @@ def log_model(artifact_path, **kwargs):
"""
with TempDir() as tmp:
local_path = tmp.path(artifact_path)
run_id = tracking.active_run().info.run_uuid
run_id = active_run().info.run_uuid
if 'model' in kwargs:
raise Exception("Unused argument 'model'. log_model creates a new model object")

save_model(dst_path=local_path, model=Model(artifact_path=artifact_path, run_id=run_id),
**kwargs)
tracking.log_artifacts(local_path, artifact_path)
log_artifacts(local_path, artifact_path)


def get_module_loader_src(src_path, dst_path):
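With `log_model` now using the fluent helpers, consuming a logged pyfunc model by run ID looks like the sketch below (not part of this diff). The artifact path and run UUID are placeholders, and the loaded object exposing `predict(pandas.DataFrame)` is the pyfunc contract assumed here; `spark_udf(spark, path, run_id=...)` from the same hunk wraps the identical lookup for use as a Spark UDF.

```python
import pandas
import mlflow.pyfunc

run_id = "0123456789abcdef"  # placeholder run UUID
model = mlflow.pyfunc.load_pyfunc("model", run_id=run_id)  # "model" = artifact_path used at log time
preds = model.predict(pandas.DataFrame({"x": [1.0, 2.0, 3.0]}))
```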
2 changes: 1 addition & 1 deletion mlflow/pyfunc/cli.py
@@ -10,7 +10,7 @@
import pandas

from mlflow.pyfunc import load_pyfunc, scoring_server, _load_model_env
from mlflow.tracking import _get_model_log_dir
from mlflow.tracking.utils import _get_model_log_dir
from mlflow.utils import cli_args
from mlflow.utils.logging_utils import eprint
from mlflow.projects import _get_conda_bin_executable, _get_or_create_conda_env
2 changes: 1 addition & 1 deletion mlflow/sagemaker/__init__.py
@@ -13,7 +13,7 @@
import mlflow.version
from mlflow import pyfunc
from mlflow.models import Model
from mlflow.tracking import _get_model_log_dir
from mlflow.tracking.utils import _get_model_log_dir
from mlflow.utils.logging_utils import eprint
from mlflow.utils.file_utils import TempDir, _copy_project

4 changes: 1 addition & 3 deletions mlflow/server/handlers.py
@@ -8,9 +8,7 @@
from google.protobuf.json_format import MessageToJson, ParseDict
from querystring_parser import parser

from mlflow.entities.metric import Metric
from mlflow.entities.param import Param
from mlflow.entities.run_tag import RunTag
from mlflow.entities import Metric, Param, RunTag
from mlflow.protos import databricks_pb2
from mlflow.protos.service_pb2 import CreateExperiment, MlflowService, GetExperiment, \
GetRun, SearchRuns, ListArtifacts, GetArtifact, GetMetricHistory, CreateRun, \
7 changes: 4 additions & 3 deletions mlflow/sklearn.py
@@ -14,6 +14,7 @@
from mlflow.utils.file_utils import TempDir
from mlflow import pyfunc
from mlflow.models import Model
from mlflow.tracking.fluent import _get_or_start_run, log_artifacts
import mlflow.tracking


@@ -48,10 +49,10 @@ def log_model(sk_model, artifact_path):
with TempDir() as tmp:
local_path = tmp.path("model")
# TODO: I get active_run_id here but mlflow.tracking.log_output_files has its own way
run_id = mlflow.tracking._get_or_start_run().run_info.run_uuid
run_id = _get_or_start_run().info.run_uuid
mlflow_model = Model(artifact_path=artifact_path, run_id=run_id)
save_model(sk_model, local_path, mlflow_model=mlflow_model)
mlflow.tracking.log_artifacts(local_path, artifact_path)
log_artifacts(local_path, artifact_path)


def _load_model_from_local_file(path):
Expand All @@ -74,7 +75,7 @@ def load_pyfunc(path):
def load_model(path, run_id=None):
"""Load a scikit-learn model from a local file (if ``run_id`` is None) or a run."""
if run_id is not None:
path = mlflow.tracking._get_model_log_dir(model_name=path, run_id=run_id)
path = mlflow.tracking.utils._get_model_log_dir(model_name=path, run_id=run_id)
return _load_model_from_local_file(path)


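End to end, the sklearn flavor now round-trips through the fluent run as sketched below (not part of this diff). The signatures follow `log_model(sk_model, artifact_path)` and `load_model(path, run_id)` from this hunk, the training data is made up, and `start_run()` yielding an object with `info.run_uuid` follows the ActiveRun usage elsewhere in this PR.

```python
from sklearn.linear_model import LogisticRegression
import mlflow
import mlflow.sklearn

model = LogisticRegression().fit([[0.0], [1.0]], [0, 1])

with mlflow.start_run() as run:
    mlflow.sklearn.log_model(model, artifact_path="model")
    run_id = run.info.run_uuid

loaded = mlflow.sklearn.load_model("model", run_id=run_id)
print(loaded.predict([[0.5]]))
```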
2 changes: 1 addition & 1 deletion mlflow/spark.py
@@ -191,7 +191,7 @@ def load_model(path, run_id=None, dfs_tmpdir=DFS_TMP):

"""
if run_id is not None:
path = mlflow.tracking._get_model_log_dir(model_name=path, run_id=run_id)
path = mlflow.tracking.utils._get_model_log_dir(model_name=path, run_id=run_id)
m = Model.load(os.path.join(path, 'MLmodel'))
if FLAVOR_NAME not in m.flavors:
raise Exception("Model does not have {} flavor".format(FLAVOR_NAME))