
Introduce time-based batch metrics logging and change XGBoost to use it #3619

Merged

Changes from 28 commits (33 commits total):
35b624b  xgboost log on every iteration with timing (andrewnitu, Oct 27, 2020)
d8fff96  get avg time (andrewnitu, Oct 27, 2020)
a3fef38  fix (andrewnitu, Oct 27, 2020)
1eb5da3  batch send all at the end of training (andrewnitu, Oct 27, 2020)
7eb0d4c  stash (andrewnitu, Oct 29, 2020)
9b1a782  rename promise to future (andrewnitu, Oct 29, 2020)
c937852  remove batch_log_interval (andrewnitu, Oct 29, 2020)
cf22cb8  make should_purge have no side effects (andrewnitu, Oct 29, 2020)
13ee723  do not assume step anymore (andrewnitu, Oct 31, 2020)
d74d638  add test case (andrewnitu, Oct 31, 2020)
05aac0e  stash (andrewnitu, Nov 2, 2020)
d99f981  autofmt (andrewnitu, Nov 2, 2020)
8c59ab8  linting (andrewnitu, Nov 2, 2020)
2e9cbdf  some cleanup and gather batch log time on initial iteration (andrewnitu, Nov 2, 2020)
9d15742  more cleanup (andrewnitu, Nov 2, 2020)
ee7614e  reimport time (andrewnitu, Nov 2, 2020)
055971e  revert changes to xgboost example (andrewnitu, Nov 2, 2020)
be1db98  add chunking test and clean up tests (andrewnitu, Nov 2, 2020)
c5a11ae  refactor chunking test (andrewnitu, Nov 2, 2020)
a36dcf5  revert adding __eq__ method to metric entity (andrewnitu, Nov 2, 2020)
6aab196  remove commented-out code (andrewnitu, Nov 2, 2020)
e36549e  fix xgboost autolog tests (andrewnitu, Nov 2, 2020)
0ec68db  remove unused import (andrewnitu, Nov 2, 2020)
1cb6057  remove unused import (andrewnitu, Nov 3, 2020)
b106659  code review (andrewnitu, Nov 4, 2020)
fd61abd  fix line lenght (andrewnitu, Nov 4, 2020)
1304367  change to total log batch time instead of average (andrewnitu, Nov 4, 2020)
c7b41ce  make test go through two cycles of batch logging (andrewnitu, Nov 5, 2020)
3c18c8b  code review (andrewnitu, Nov 5, 2020)
1d0c8aa  some code review (andrewnitu, Nov 6, 2020)
f665e4f  code review (andrewnitu, Nov 9, 2020)
bd71788  remove extra param from xgboost example (andrewnitu, Nov 10, 2020)
26dc835  nit fix (andrewnitu, Nov 11, 2020)
85 changes: 85 additions & 0 deletions mlflow/utils/autologging_utils.py
@@ -1,9 +1,14 @@
import inspect
import functools
import warnings
import time
import contextlib

import mlflow
from mlflow.utils import gorilla
from mlflow.entities import Metric
from mlflow.tracking.client import MlflowClient
from mlflow.utils.validation import MAX_METRICS_PER_BATCH


INPUT_EXAMPLE_SAMPLE_ROWS = 5
@@ -187,3 +192,83 @@ def resolve_input_example_and_signature(
logger.warning(model_signature_user_msg)

return input_example if log_input_example else None, model_signature


class BatchMetricsLogger:
def __init__(self, run_id):
self.run_id = run_id

# data is an array of Metric objects
self.data = []
self.total_training_time = 0
self.total_log_batch_time = 0
self.previous_training_timestamp = None

def _purge(self):
self._timed_log_batch()
self.data = []

def _timed_log_batch(self):
start = time.time()
metrics_slices = [
self.data[i : i + MAX_METRICS_PER_BATCH]
for i in range(0, len(self.data), MAX_METRICS_PER_BATCH)
]
for metrics_slice in metrics_slices:
MlflowClient().log_batch(run_id=self.run_id, metrics=metrics_slice)
Collaborator:
Can we wrap this in a call to try_mlflow_log() to ensure that failures don't prevent future metrics from being logged?

We should document this behavior and, if possible, we should add a test case for it too.
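
A minimal sketch of the suggested wrapping, for illustration only (this assumes the module's existing try_mlflow_log helper, which warns on failure instead of raising; it is not part of the diff above):

for metrics_slice in metrics_slices:
    # hypothetical: try_mlflow_log catches logging errors and emits a warning,
    # so a failure on one chunk does not abort the remaining chunks
    try_mlflow_log(
        MlflowClient().log_batch, run_id=self.run_id, metrics=metrics_slice
    )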

end = time.time()
self.total_log_batch_time += end - start

def _should_purge(self):
if self.total_log_batch_time == 0: # we don't yet have data on how long logging takes
return True
Collaborator:
Now I don't think we need this :)


log_batch_time_fudge_factor = 10
Collaborator:
Fudge factor seems like the wrong term here. This is the desired ratio of training time to batch logging time. Perhaps target_training_to_logging_time_ratio?
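
A sketch of how the check might read with that rename (behavior unchanged; this is not an actual suggestion block from the PR):

target_training_to_logging_time_ratio = 10
if self.total_training_time >= (
    self.total_log_batch_time * target_training_to_logging_time_ratio
):
    return True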

if self.total_training_time >= self.total_log_batch_time * log_batch_time_fudge_factor:
return True

return False

def record_metrics(self, metrics, step):
"""
Submit a set of metrics to be logged. The metrics may not be immediately logged, as this
class will batch them in order to not increase execution time too much by logging
frequently.

:param metrics: dictionary containing key, value pairs of metrics to be logged.
:param step: the training step that the metrics correspond to.
"""
current_timestamp = time.time()
if self.previous_training_timestamp is None:
self.previous_training_timestamp = current_timestamp

training_time = current_timestamp - self.previous_training_timestamp

self.total_training_time += training_time

for key, value in metrics.items():
self.data.append(Metric(key, value, current_timestamp, step))
Collaborator:
The metric timestamp must be an integer value with millisecond resolution; e.g., this should be the following, based on what we do in the fluent API (timestamp = int(time.time() * 1000)):

Suggested change
self.data.append(Metric(key, value, current_timestamp, step))
self.data.append(Metric(key, value, int(current_timestamp * 1000), step))

I found this by running our XGBoost example, where I encountered warning logs from the file store about operating on timestamp content in float form, rather than integer form:

/Users/czumar/mlflow/mlflow/xgboost.py:412: DeprecationWarning: inspect.getargspec() is deprecated since Python 3.0, use inspect.signature() or inspect.getfullargspec()
  all_arg_names = inspect.getargspec(original)[0]  # pylint: disable=W1505
[0]     train-mlogloss:0.74723
[1]     train-mlogloss:0.54060
[2]     train-mlogloss:0.40276
[3]     train-mlogloss:0.30789
[4]     train-mlogloss:0.24052
[5]     train-mlogloss:0.19087
[6]     train-mlogloss:0.15471
[7]     train-mlogloss:0.12807
[8]     train-mlogloss:0.10722
[9]     train-mlogloss:0.09053
/Users/czumar/mlflow/mlflow/xgboost.py:387: UserWarning: Logging to MLflow failed: invalid literal for int() with base 10: '1604679208.8766131'
  try_mlflow_log(mlflow.log_artifact, filepath)
/Users/czumar/mlflow/mlflow/xgboost.py:465: UserWarning: Logging to MLflow failed: invalid literal for int() with base 10: '1604679208.8766131'
  try_mlflow_log(mlflow.log_artifact, filepath)
/Users/czumar/mlflow/mlflow/xgboost.py:501: UserWarning: Logging to MLflow failed: invalid literal for int() with base 10: '1604679208.8766131'
  input_example=input_example,

^ Interestingly, the error is only encountered when subsequent artifact logging operations are called, at which point the file store reads the logged metric files (not sure why it needs to do this, but it does)

Collaborator:
Can we make sure to add a test case for this, ensuring that timestamps have millisecond resolution and are integer values?
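
A hedged sketch of such a test, mirroring the style of the existing tests in tests/utils/test_autologging_utils.py (the test name and thresholds are hypothetical):

def test_batch_metrics_logger_logs_integer_millisecond_timestamps(
    start_run,
):  # pylint: disable=unused-argument
    with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:
        run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        with with_batch_metrics_logger(run_id) as batch_metrics_logger:
            batch_metrics_logger.record_metrics({"x": 1}, step=0)

        for call in log_batch_mock.call_args_list:
            _, kwargs = call
            for metric in kwargs["metrics"]:
                # millisecond timestamps are integers on the order of 1e12;
                # a seconds-based value would be roughly 1000x smaller
                assert isinstance(metric.timestamp, int)
                assert metric.timestamp > 10 ** 11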


if self._should_purge():
self._purge()

self.previous_training_timestamp = current_timestamp


@contextlib.contextmanager
def with_batch_metrics_logger(run_id):
Collaborator:
Nit: Let's drop the leading with so that usages become with batch_metrics_logger rather than with with_batch_metrics_logger:

Suggested change
def with_batch_metrics_logger(run_id):
def batch_metrics_logger(run_id):

"""
Context manager that yields a BatchMetricsLogger object, which metrics can be logged against.
The BatchMetricsLogger will keep metrics in a list until it decides they should be logged, at
which point the accumulated metrics will be batch logged. The BatchMetricsLogger will ensure
that logging imposes no more than a 10% overhead on the training, where the training is
measured by adding up the time elapsed between consecutive calls to record_metrics.

Once the context is closed, any metrics that have yet to be logged will be logged.

:param run_id: ID of the run that the metrics will be logged to.
Collaborator:
Nit: Instead of future tense, we should favor using present tense where possible (e.g., The BatchMetricsLogger keeps metrics in a list ... instead of The BatchMetricsLogger will keep metrics in a list...)

"""

with_batch_metrics_logger = BatchMetricsLogger(run_id)
yield with_batch_metrics_logger
with_batch_metrics_logger._purge()
Collaborator:
Suggested change
with_batch_metrics_logger = BatchMetricsLogger(run_id)
yield with_batch_metrics_logger
with_batch_metrics_logger._purge()
metrics_logger = BatchMetricsLogger(run_id)
yield metrics_logger
metrics_logger._purge()
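
For reference, a minimal usage sketch of the context manager (names follow the suggested rename; the metric values and loop are placeholders):

run_id = mlflow.active_run().info.run_id
with batch_metrics_logger(run_id) as metrics_logger:
    for step, loss in enumerate([0.9, 0.5, 0.3]):  # stand-in for a training loop
        metrics_logger.record_metrics({"train-loss": loss}, step)
# on exit, any metrics still buffered are flushed via one or more log_batch calls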

33 changes: 17 additions & 16 deletions mlflow/xgboost.py
@@ -45,6 +45,7 @@
resolve_input_example_and_signature,
_InputExampleInfo,
ENSURE_AUTOLOGGING_ENABLED_TEXT,
with_batch_metrics_logger,
)
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS

@@ -334,12 +335,13 @@ def __init__(self, *args, **kwargs):
original(self, *args, **kwargs)

def train(*args, **kwargs):
def record_eval_results(eval_results):
def record_eval_results(eval_results, batch_metrics_logger):
"""
Create a callback function that records evaluation results.
"""

def callback(env):
batch_metrics_logger.record_metrics(dict(env.evaluation_result_list), env.iteration)
eval_results.append(dict(env.evaluation_result_list))

return callback
@@ -413,22 +415,21 @@ def log_feature_importance_plot(features, importance, importance_type):
# adding a callback that records evaluation results.
eval_results = []
callbacks_index = all_arg_names.index("callbacks")
callback = record_eval_results(eval_results)
if num_pos_args >= callbacks_index + 1:
tmp_list = list(args)
tmp_list[callbacks_index] += [callback]
args = tuple(tmp_list)
elif "callbacks" in kwargs and kwargs["callbacks"] is not None:
kwargs["callbacks"] += [callback]
else:
kwargs["callbacks"] = [callback]

# training model
model = original(*args, **kwargs)

# logging metrics on each iteration.
for idx, metrics in enumerate(eval_results):
try_mlflow_log(mlflow.log_metrics, metrics, step=idx)
run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
Collaborator:
Can we use mlflow.active_run().info.run_id instead? A run should have already been created on line 350 (either that or an active run already existed prior to autologging)

with with_batch_metrics_logger(run_id) as batch_metrics_logger:
callback = record_eval_results(eval_results, batch_metrics_logger)
if num_pos_args >= callbacks_index + 1:
tmp_list = list(args)
tmp_list[callbacks_index] += [callback]
args = tuple(tmp_list)
elif "callbacks" in kwargs and kwargs["callbacks"] is not None:
kwargs["callbacks"] += [callback]
else:
kwargs["callbacks"] = [callback]

# training model
model = original(*args, **kwargs)

# If early_stopping_rounds is present, logging metrics at the best iteration
# as extra metrics with the max step + 1.
102 changes: 102 additions & 0 deletions tests/utils/test_autologging_utils.py
@@ -1,14 +1,19 @@
import inspect
import time
import pytest
from unittest.mock import Mock, call
from unittest import mock


import mlflow
from mlflow.utils import gorilla
from mlflow.tracking.client import MlflowClient
from mlflow.utils.autologging_utils import (
get_unspecified_default_args,
log_fn_args_as_params,
wrap_patch,
resolve_input_example_and_signature,
with_batch_metrics_logger,
)

# Example function signature we are testing on
@@ -263,3 +268,100 @@ def modifies(_):

assert x["data"] == 0
logger.warning.assert_not_called()


def test_batch_metrics_logger_logs_all_metrics(start_run): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:
Collaborator:
Do we need to mock log_batch? Ideally, it would be nice to test that the metrics are actually logged by leaving this unmocked and querying the run data after the iterative calls to record_metrics() complete.

Collaborator Author:
Do we want to do this for every test? Or is it sufficient to do it for one test and know that it works end-to-end, then save time by keeping the mock for the rest of the tests?

Collaborator Author:
Actually, it seems this really only applies to one of the tests; the rest care about the intermediate state while the BatchMetricsLogger is logging, not the final outcome, so it makes more sense to use mocks there.
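
A hedged sketch of the unmocked, end-to-end variant the reviewer describes (hypothetical test name; it relies on MlflowClient.get_metric_history to read back what was persisted):

def test_batch_metrics_logger_persists_all_metrics(start_run):  # pylint: disable=unused-argument
    run_id = mlflow.active_run().info.run_id
    with with_batch_metrics_logger(run_id) as batch_metrics_logger:
        for i in range(10):
            batch_metrics_logger.record_metrics({"x": i}, i)

    # read back the run data that was actually logged, rather than mocking log_batch
    history = MlflowClient().get_metric_history(run_id, "x")
    assert sorted(m.value for m in history) == list(range(10))
    assert sorted(m.step for m in history) == list(range(10))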

run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
with with_batch_metrics_logger(run_id) as batch_metrics_logger:
for i in range(100):
batch_metrics_logger.record_metrics({"x": i}, i)

# collect the args of all the logging calls
recorded_metrics = []
for call in log_batch_mock.call_args_list:
_, kwargs = call
metrics_arr = kwargs["metrics"]
for metric in metrics_arr:
recorded_metrics.append({metric._key: metric._value})

desired_metrics = [{"x": i} for i in range(100)]

assert recorded_metrics == desired_metrics


def test_batch_metrics_logger_runs_training_and_logging_in_correct_ratio(
start_run,
): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:
run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
with with_batch_metrics_logger(run_id) as batch_metrics_logger:
batch_metrics_logger.record_metrics({"x": 1}, step=0) # data doesn't matter
# first metrics should be skipped to record a previous timestamp and batch log time
Collaborator:
"skipped" seems to imply that we're dropping metrics or not logging them; I think we mean that we're logging them immediately (i.e., "skipping waiting").

log_batch_mock.assert_called_once()

batch_metrics_logger.total_log_batch_time = 1
batch_metrics_logger.total_training_time = 1

log_batch_mock.reset_mock() # resets the 'calls' of this mock

# the above 'training' took 1 second. So with fudge factor of 10x,
Collaborator:
As above, I don't think that fudge factor is accurate terminology for this use case

# 9 more 'training' should happen before the metrics are sent and
# then the 10th should send them.
for i in range(2, 11):
batch_metrics_logger.record_metrics({"x": 1}, step=0)
log_batch_mock.assert_not_called()
batch_metrics_logger.total_training_time = i

# at this point, average log batch time is 1, and total training time is 10
# thus the next record_metrics call should send the batch.
batch_metrics_logger.record_metrics({"x": 1}, step=0)
log_batch_mock.assert_called_once()

# update log_batch time to reflect the 'mocked' training time
batch_metrics_logger.total_log_batch_time = 2

log_batch_mock.reset_mock() # reset the recorded calls

for i in range(12, 21):
batch_metrics_logger.record_metrics({"x": 1}, step=0)
log_batch_mock.assert_not_called()
batch_metrics_logger.total_training_time = i

batch_metrics_logger.record_metrics({"x": 1}, step=0)
log_batch_mock.assert_called_once()


def test_batch_metrics_logger_chunks_metrics_when_batch_logging(
start_run,
): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:
run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
with with_batch_metrics_logger(run_id) as batch_metrics_logger:
batch_metrics_logger.record_metrics({hex(x): x for x in range(5000)}, step=0)
run_id = mlflow.active_run().info.run_id

for call_idx, call in enumerate(log_batch_mock.call_args_list):
_, kwargs = call

assert kwargs["run_id"] == run_id
assert len(kwargs["metrics"]) == 1000
for metric_idx, metric in enumerate(kwargs["metrics"]):
assert metric.key == hex(call_idx * 1000 + metric_idx)
assert metric.value == call_idx * 1000 + metric_idx
assert metric.step == 0


def test_batch_metrics_logger_records_time_correctly(start_run,): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch", wraps=lambda *args, **kwargs: time.sleep(1)):
run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
with with_batch_metrics_logger(run_id) as batch_metrics_logger:
batch_metrics_logger.record_metrics({"x": 1}, step=0)

assert batch_metrics_logger.total_log_batch_time >= 1

time.sleep(2)

batch_metrics_logger.record_metrics({"x": 1}, step=0)

assert batch_metrics_logger.total_training_time >= 2