New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce time-based batch metrics logging and change XGBoost to use it #3619
Changes from 28 commits
35b624b
d8fff96
a3fef38
1eb5da3
7eb0d4c
9b1a782
c937852
cf22cb8
13ee723
d74d638
05aac0e
d99f981
8c59ab8
2e9cbdf
9d15742
ee7614e
055971e
be1db98
c5a11ae
a36dcf5
6aab196
e36549e
0ec68db
1cb6057
b106659
fd61abd
1304367
c7b41ce
3c18c8b
1d0c8aa
f665e4f
bd71788
26dc835
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,9 +1,14 @@ | ||||||||||||||
import inspect | ||||||||||||||
import functools | ||||||||||||||
import warnings | ||||||||||||||
import time | ||||||||||||||
import contextlib | ||||||||||||||
|
||||||||||||||
import mlflow | ||||||||||||||
from mlflow.utils import gorilla | ||||||||||||||
from mlflow.entities import Metric | ||||||||||||||
from mlflow.tracking.client import MlflowClient | ||||||||||||||
from mlflow.utils.validation import MAX_METRICS_PER_BATCH | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
INPUT_EXAMPLE_SAMPLE_ROWS = 5 | ||||||||||||||
|
@@ -187,3 +192,83 @@ def resolve_input_example_and_signature( | |||||||||||||
logger.warning(model_signature_user_msg) | ||||||||||||||
|
||||||||||||||
return input_example if log_input_example else None, model_signature | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class BatchMetricsLogger:
    """
    Buffers metrics in memory and logs them to MLflow in batches, throttling
    backend calls so that metric-logging overhead stays small relative to the
    time spent training between calls to :meth:`record_metrics`.
    """

    # Desired ratio of accumulated training time to accumulated batch-logging
    # time. Buffered metrics are only flushed once training time is at least
    # this many times the time spent logging, bounding logging overhead ~10%.
    _TRAINING_TO_LOGGING_RATIO = 10

    def __init__(self, run_id):
        """
        :param run_id: ID of the run that buffered metrics are logged to.
        """
        self.run_id = run_id

        # Pending, not-yet-logged metrics: a list of Metric entities.
        self.data = []
        # Cumulative wall-clock seconds elapsed between record_metrics calls.
        self.total_training_time = 0
        # Cumulative wall-clock seconds spent inside log_batch calls.
        self.total_log_batch_time = 0
        # Timestamp (epoch seconds) of the previous record_metrics call, or
        # None if no metrics have been recorded yet.
        self.previous_training_timestamp = None

    def _purge(self):
        """Log all buffered metrics and clear the buffer."""
        self._timed_log_batch()
        self.data = []

    def _timed_log_batch(self):
        """
        Log buffered metrics in MAX_METRICS_PER_BATCH-sized chunks,
        accumulating the elapsed wall-clock time into total_log_batch_time.
        """
        start = time.time()
        metrics_slices = [
            self.data[i : i + MAX_METRICS_PER_BATCH]
            for i in range(0, len(self.data), MAX_METRICS_PER_BATCH)
        ]
        for metrics_slice in metrics_slices:
            MlflowClient().log_batch(run_id=self.run_id, metrics=metrics_slice)
        end = time.time()
        self.total_log_batch_time += end - start

    def _should_purge(self):
        """Return True when the metric buffer should be flushed to the backend."""
        if self.total_log_batch_time == 0:  # we don't yet have data on how long logging takes
            return True

        # Only flush once training time dominates logging time by the target
        # ratio, which keeps logging overhead bounded.
        return (
            self.total_training_time
            >= self.total_log_batch_time * self._TRAINING_TO_LOGGING_RATIO
        )

    def record_metrics(self, metrics, step):
        """
        Submit a set of metrics to be logged. The metrics may not be immediately logged, as this
        class will batch them in order to not increase execution time too much by logging
        frequently.

        :param metrics: dictionary containing key, value pairs of metrics to be logged.
        :param step: the training step that the metrics correspond to.
        """
        current_timestamp = time.time()
        if self.previous_training_timestamp is None:
            self.previous_training_timestamp = current_timestamp

        # Time elapsed since the previous call is attributed to training.
        training_time = current_timestamp - self.previous_training_timestamp
        self.total_training_time += training_time

        # MLflow metric timestamps must be integer values with millisecond
        # resolution; passing the raw float epoch-seconds value corrupts
        # file-store metric files (see fluent.log_metric for the convention).
        metric_timestamp_ms = int(current_timestamp * 1000)
        for key, value in metrics.items():
            self.data.append(Metric(key, value, metric_timestamp_ms, step))

        if self._should_purge():
            self._purge()

        self.previous_training_timestamp = current_timestamp
|
||||||||||||||
|
||||||||||||||
@contextlib.contextmanager
def with_batch_metrics_logger(run_id):
    """
    Context manager that yields a BatchMetricsLogger object, which metrics can be logged against.
    The BatchMetricsLogger keeps metrics in a list until it decides they should be logged, at
    which point the accumulated metrics are batch logged. The BatchMetricsLogger ensures
    that logging imposes no more than a 10% overhead on the training, where the training is
    measured by adding up the time elapsed between consecutive calls to record_metrics.

    Once the context is closed, any metrics that have yet to be logged are logged.

    :param run_id: ID of the run that the metrics are logged to.
    """
    batch_metrics_logger = BatchMetricsLogger(run_id)
    try:
        yield batch_metrics_logger
    finally:
        # Flush any still-buffered metrics even if the managed block raised,
        # so recorded metrics are never silently dropped.
        batch_metrics_logger._purge()
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ | |
resolve_input_example_and_signature, | ||
_InputExampleInfo, | ||
ENSURE_AUTOLOGGING_ENABLED_TEXT, | ||
with_batch_metrics_logger, | ||
) | ||
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS | ||
|
||
|
@@ -334,12 +335,13 @@ def __init__(self, *args, **kwargs): | |
original(self, *args, **kwargs) | ||
|
||
def train(*args, **kwargs): | ||
def record_eval_results(eval_results): | ||
def record_eval_results(eval_results, batch_metrics_logger):
    """
    Create a callback function that records evaluation results.
    """

    def callback(env):
        # Record this iteration's metrics with the batch logger and keep a
        # separate copy in the caller-provided accumulator list.
        eval_metrics = dict(env.evaluation_result_list)
        batch_metrics_logger.record_metrics(eval_metrics, env.iteration)
        eval_results.append(dict(env.evaluation_result_list))

    return callback
|
@@ -413,22 +415,21 @@ def log_feature_importance_plot(features, importance, importance_type): | |
# adding a callback that records evaluation results. | ||
eval_results = [] | ||
callbacks_index = all_arg_names.index("callbacks") | ||
callback = record_eval_results(eval_results) | ||
if num_pos_args >= callbacks_index + 1: | ||
tmp_list = list(args) | ||
tmp_list[callbacks_index] += [callback] | ||
args = tuple(tmp_list) | ||
elif "callbacks" in kwargs and kwargs["callbacks"] is not None: | ||
kwargs["callbacks"] += [callback] | ||
else: | ||
kwargs["callbacks"] = [callback] | ||
|
||
# training model | ||
model = original(*args, **kwargs) | ||
|
||
# logging metrics on each iteration. | ||
for idx, metrics in enumerate(eval_results): | ||
try_mlflow_log(mlflow.log_metrics, metrics, step=idx) | ||
run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use |
||
with with_batch_metrics_logger(run_id) as batch_metrics_logger: | ||
callback = record_eval_results(eval_results, batch_metrics_logger) | ||
if num_pos_args >= callbacks_index + 1: | ||
tmp_list = list(args) | ||
tmp_list[callbacks_index] += [callback] | ||
args = tuple(tmp_list) | ||
elif "callbacks" in kwargs and kwargs["callbacks"] is not None: | ||
kwargs["callbacks"] += [callback] | ||
else: | ||
kwargs["callbacks"] = [callback] | ||
|
||
# training model | ||
model = original(*args, **kwargs) | ||
|
||
# If early_stopping_rounds is present, logging metrics at the best iteration | ||
# as extra metrics with the max step + 1. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,19 @@ | ||
import inspect | ||
import time | ||
import pytest | ||
from unittest.mock import Mock, call | ||
from unittest import mock | ||
|
||
|
||
import mlflow | ||
from mlflow.utils import gorilla | ||
from mlflow.tracking.client import MlflowClient | ||
from mlflow.utils.autologging_utils import ( | ||
get_unspecified_default_args, | ||
log_fn_args_as_params, | ||
wrap_patch, | ||
resolve_input_example_and_signature, | ||
with_batch_metrics_logger, | ||
) | ||
|
||
# Example function signature we are testing on | ||
|
@@ -263,3 +268,100 @@ def modifies(_): | |
|
||
assert x["data"] == 0 | ||
logger.warning.assert_not_called() | ||
|
||
|
||
def test_batch_metrics_logger_logs_all_metrics(start_run):  # pylint: disable=unused-argument
    """Every metric submitted to the logger is eventually batch-logged, in order."""
    with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:
        run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        with with_batch_metrics_logger(run_id) as batch_metrics_logger:
            for i in range(100):
                batch_metrics_logger.record_metrics({"x": i}, i)

        # Flatten the metrics from every log_batch invocation, preserving order.
        recorded_metrics = [
            {metric._key: metric._value}
            for _, kwargs in log_batch_mock.call_args_list
            for metric in kwargs["metrics"]
        ]

        assert recorded_metrics == [{"x": i} for i in range(100)]
|
||
|
||
def test_batch_metrics_logger_runs_training_and_logging_in_correct_ratio(
    start_run,
):  # pylint: disable=unused-argument
    """The logger only flushes once training time outweighs logging time 10:1."""
    with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:
        run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        with with_batch_metrics_logger(run_id) as batch_metrics_logger:
            # With no logging-time data yet, the very first record flushes
            # immediately so that a baseline batch-log duration is measured.
            batch_metrics_logger.record_metrics({"x": 1}, step=0)  # data doesn't matter
            log_batch_mock.assert_called_once()

            # Pretend one second was spent on each of training and logging.
            batch_metrics_logger.total_log_batch_time = 1
            batch_metrics_logger.total_training_time = 1
            log_batch_mock.reset_mock()  # discard the recorded calls

            # At a 10:1 target ratio, nine more simulated 'training' seconds
            # must accumulate with no flush; the call after that flushes.
            for simulated_training_time in range(2, 11):
                batch_metrics_logger.record_metrics({"x": 1}, step=0)
                log_batch_mock.assert_not_called()
                batch_metrics_logger.total_training_time = simulated_training_time

            # Logging time is now 1s against 10s of training, so the next
            # record_metrics call sends the batch.
            batch_metrics_logger.record_metrics({"x": 1}, step=0)
            log_batch_mock.assert_called_once()

            # Bump the recorded logging time and repeat the cycle.
            batch_metrics_logger.total_log_batch_time = 2
            log_batch_mock.reset_mock()  # discard the recorded calls

            for simulated_training_time in range(12, 21):
                batch_metrics_logger.record_metrics({"x": 1}, step=0)
                log_batch_mock.assert_not_called()
                batch_metrics_logger.total_training_time = simulated_training_time

            batch_metrics_logger.record_metrics({"x": 1}, step=0)
            log_batch_mock.assert_called_once()
|
||
|
||
def test_batch_metrics_logger_chunks_metrics_when_batch_logging(
    start_run,
):  # pylint: disable=unused-argument
    """A single oversized record is split into 1000-metric log_batch calls."""
    with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:
        run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        with with_batch_metrics_logger(run_id) as batch_metrics_logger:
            batch_metrics_logger.record_metrics({hex(x): x for x in range(5000)}, step=0)
            run_id = mlflow.active_run().info.run_id

        for call_idx, logged_call in enumerate(log_batch_mock.call_args_list):
            _, kwargs = logged_call

            assert kwargs["run_id"] == run_id
            assert len(kwargs["metrics"]) == 1000
            for metric_idx, metric in enumerate(kwargs["metrics"]):
                expected_value = call_idx * 1000 + metric_idx
                assert metric.key == hex(expected_value)
                assert metric.value == expected_value
                assert metric.step == 0
|
||
|
||
def test_batch_metrics_logger_records_time_correctly(start_run):  # pylint: disable=unused-argument
    """total_log_batch_time and total_training_time track real elapsed time."""
    with mock.patch.object(MlflowClient, "log_batch", wraps=lambda *args, **kwargs: time.sleep(1)):
        run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        with with_batch_metrics_logger(run_id) as batch_metrics_logger:
            # The first record flushes immediately, so the 1s sleep inside the
            # patched log_batch is attributed to logging time.
            batch_metrics_logger.record_metrics({"x": 1}, step=0)
            assert batch_metrics_logger.total_log_batch_time >= 1

            # Time elapsed between consecutive record_metrics calls is
            # attributed to training.
            time.sleep(2)
            batch_metrics_logger.record_metrics({"x": 1}, step=0)
            assert batch_metrics_logger.total_training_time >= 2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we wrap this in a call to
try_mlflow_log()
to ensure that failures don't prevent future metrics from being logged? We should document this behavior and, if possible, we should add a test case for it too.