Introduce time-based batch metrics logging and change XGBoost to use it #3619
Changes from 24 commits
@@ -1,9 +1,14 @@
import inspect
import functools
import warnings
import time
import contextlib

import mlflow
from mlflow.utils import gorilla
from mlflow.entities import Metric
from mlflow.tracking.client import MlflowClient
from mlflow.utils.validation import MAX_METRICS_PER_BATCH


INPUT_EXAMPLE_SAMPLE_ROWS = 5
@@ -187,3 +192,98 @@ def resolve_input_example_and_signature(
        logger.warning(model_signature_user_msg)

    return input_example if log_input_example else None, model_signature


# wrapper functions to be able to mock these easily in the tests
def time_wrapper_for_log():
    return time.time()


def time_wrapper_for_current():
    return time.time()


def time_wrapper_for_timestamp():
    return time.time()


# we pass the batch_metrics_handler through so that the callback can access it
def _timed_log_batch(batch_metrics_handler, run_id, metrics):
[Review comment] Can we move this into …
    start = time_wrapper_for_log()
    metrics_slices = [
        metrics[i * MAX_METRICS_PER_BATCH : (i + 1) * MAX_METRICS_PER_BATCH]
        for i in range((len(metrics) + MAX_METRICS_PER_BATCH - 1) // MAX_METRICS_PER_BATCH)
    ]

[Review comment] I think using the … (Credits to https://stackoverflow.com/a/312464/11952869)

    for metrics_slice in metrics_slices:
        MlflowClient().log_batch(run_id=run_id, metrics=metrics_slice)
    end = time_wrapper_for_log()
    batch_metrics_handler.total_log_batch_time += end - start
    batch_metrics_handler.num_log_batch += 1
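The truncated review suggestion above credits the classic chunking recipe from Stack Overflow; as a sketch, the slicing comprehension in `_timed_log_batch` could be replaced by a generator along these lines (the name `chunks` comes from that answer, not from this PR):

```python
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

# behaves like the slicing comprehension in _timed_log_batch:
assert list(chunks(list(range(5)), 2)) == [[0, 1], [2, 3], [4]]
```

The generator avoids building all slices up front, though for metric lists of a few thousand entries the difference is mostly readability.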
class BatchMetricsHandler:  # BatchMetricsLogger maybe?

[Review comment] +1. Let's call this …
    def __init__(self):
        # data is an array of tuples of the form (timestamp, metrics at timestamp)
        self.data = {}

[Review comment] Seems like a dictionary to me! (Though I think it should be a list of …)

        self.total_training_time = 0
        self.total_log_batch_time = 0
        self.num_log_batch = 0
        self.previous_training_timestamp = None
    def _purge(self):
        run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id

[Review comment] For simplicity, I think it makes sense to tie …
[Reply] sure, that makes sense. I just needed the run_id and that was my first thought to grab it :D
        final_metrics = []

        for step, metrics_at_step in self.data.items():
            for entry in metrics_at_step:
                timestamp = entry[0]
                metrics_at_timestamp = entry[1]

                for key, value in metrics_at_timestamp.items():
                    final_metrics.append(Metric(key, value, timestamp, step))

        _timed_log_batch(self, run_id=run_id, metrics=final_metrics)

        self.data = {}
    def _should_purge(self):
        if self.num_log_batch == 0:
            return True
        # we give some extra time in case of network slowdown
        log_batch_time_fudge_factor = 10

[Review comment] Fudge factor seems like the wrong term here. This is the desired ratio of training time to batch logging time. Perhaps …

        if (
            self.total_training_time
            >= self.total_log_batch_time / self.num_log_batch * log_batch_time_fudge_factor
        ):
            return True

        return False
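The check above can be read as "flush once accumulated training time is at least 10x the average cost of one `log_batch` call", which is the ratio framing the reviewer suggests. A standalone sketch of the same arithmetic (function and constant names are illustrative, not from the PR):

```python
TRAINING_TO_LOGGING_RATIO = 10  # hypothetical name for log_batch_time_fudge_factor

def should_purge(total_training_time, total_log_batch_time, num_log_batch):
    # Nothing logged yet: flush immediately to get a first timing sample.
    if num_log_batch == 0:
        return True
    avg_log_batch_time = total_log_batch_time / num_log_batch
    # Flush only when training time has "paid for" the next logging call 10x over,
    # keeping logging overhead to roughly 10% of total run time.
    return total_training_time >= avg_log_batch_time * TRAINING_TO_LOGGING_RATIO

assert should_purge(0, 0, 0) is True
assert should_purge(10, 2, 2) is True   # avg log time 1s, 10s trained -> flush
assert should_purge(5, 2, 2) is False   # only 5s trained -> keep buffering
```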
    # metrics is a dict representing the set of metrics collected during one iteration

[Review comment] Can we convert this comment into a docstring?
    def record_metrics(self, metrics, step):
        current_timestamp = time_wrapper_for_current()
        if self.previous_training_timestamp is None:
            self.previous_training_timestamp = current_timestamp

        training_time = current_timestamp - self.previous_training_timestamp

        self.total_training_time += training_time

        if step in self.data:
            self.data[step].append([int(time_wrapper_for_timestamp() * 1000), metrics])

[Review comment] Can …
[Review comment] On second thought, I'm not sure we need these timer wrappers. See #3619 (comment)

        else:
            self.data[step] = [[int(time_wrapper_for_timestamp() * 1000), metrics]]

[Review comment] Can we construct …
[Reply] yeah i don't see why not. Seems kinda wasteful to group them by step then ungroup them again as i'm doing now

        if self._should_purge():
            self.total_training_time = 0

            self._purge()

        self.previous_training_timestamp = current_timestamp
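The reviewer's suggestion above, constructing the final entries directly instead of grouping by step in `record_metrics` and ungrouping in `_purge`, could look roughly like this. `MetricEntry` is a stdlib stand-in for `mlflow.entities.Metric`, used only so the sketch runs without MLflow:

```python
import time
from collections import namedtuple

# Stand-in for mlflow.entities.Metric (illustration only)
MetricEntry = namedtuple("MetricEntry", ["key", "value", "timestamp", "step"])

def record_as_flat_list(buffer, metrics, step):
    """Append ready-to-log entries; _purge would then just batch-send the list."""
    timestamp = int(time.time() * 1000)  # MLflow metric timestamps are milliseconds
    for key, value in metrics.items():
        buffer.append(MetricEntry(key, value, timestamp, step))

buffer = []
record_as_flat_list(buffer, {"loss": 0.5, "acc": 0.9}, step=3)
assert {entry.key for entry in buffer} == {"loss", "acc"}
assert all(entry.step == 3 for entry in buffer)
```

With a flat list, `_purge` shrinks to a single `_timed_log_batch` call over `buffer`, removing the nested unpacking loops.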
@contextlib.contextmanager
def with_batch_metrics_handler():

[Review comment] Can we add docs here?

    batch_metrics_handler = BatchMetricsHandler()
    yield batch_metrics_handler
    batch_metrics_handler._purge()
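The context manager guarantees a final `_purge()` so metrics still buffered when training ends are not lost. A self-contained sketch of the same pattern with a toy buffer (names hypothetical; note that, like the PR code, it has no try/finally, so the final flush is skipped if the body raises):

```python
import contextlib

class ToyBuffer:
    """Minimal stand-in for BatchMetricsHandler: accumulate, then flush."""
    def __init__(self):
        self.pending = []
        self.flushed = []

    def record(self, item):
        self.pending.append(item)

    def _purge(self):
        self.flushed.extend(self.pending)
        self.pending = []

@contextlib.contextmanager
def with_toy_buffer():
    buf = ToyBuffer()
    yield buf
    buf._purge()  # final flush on normal exit

with with_toy_buffer() as buf:
    buf.record({"loss": 0.5})
assert buf.flushed == [{"loss": 0.5}] and buf.pending == []
```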
@@ -45,6 +45,7 @@
    resolve_input_example_and_signature,
    _InputExampleInfo,
    ENSURE_AUTOLOGGING_ENABLED_TEXT,
    with_batch_metrics_handler,
)
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
@@ -340,6 +341,9 @@ def record_eval_results(eval_results):
    """

    def callback(env):
        batch_metrics_handler.record_metrics(

[Review comment] Can we add …

            dict(env.evaluation_result_list), env.iteration
        )
        eval_results.append(dict(env.evaluation_result_list))

    return callback
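For context on the callback above: in the legacy xgboost callback API, the `env` passed to a callback carries `evaluation_result_list` as a list of `(metric_name, value)` pairs, which `dict()` converts into the mapping `record_metrics` expects. `CallbackEnv` below is a stdlib stand-in used only to make the sketch runnable without xgboost:

```python
from collections import namedtuple

# Stand-in for xgboost's callback env (illustration only)
CallbackEnv = namedtuple("CallbackEnv", ["evaluation_result_list", "iteration"])

env = CallbackEnv([("train-rmse", 0.25), ("valid-rmse", 0.31)], iteration=7)
metrics = dict(env.evaluation_result_list)
assert metrics == {"train-rmse": 0.25, "valid-rmse": 0.31}
assert env.iteration == 7  # becomes the metrics' step
```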
@@ -423,12 +427,10 @@ def log_feature_importance_plot(features, importance, importance_type):
     else:
         kwargs["callbacks"] = [callback]

-    # training model
-    model = original(*args, **kwargs)
-
-    # logging metrics on each iteration.
-    for idx, metrics in enumerate(eval_results):
-        try_mlflow_log(mlflow.log_metrics, metrics, step=idx)
+    # logging metrics on each iteration
+    with with_batch_metrics_handler() as batch_metrics_handler:
+        # training model
+        model = original(*args, **kwargs)

     # If early_stopping_rounds is present, logging metrics at the best iteration
     # as extra metrics with the max step + 1.
@@ -1,14 +1,19 @@
import inspect
import pytest
import itertools
from unittest.mock import Mock, call
from unittest import mock


import mlflow
from mlflow.utils import gorilla
from mlflow.tracking.client import MlflowClient
from mlflow.utils.autologging_utils import (
    get_unspecified_default_args,
    log_fn_args_as_params,
    wrap_patch,
    resolve_input_example_and_signature,
    with_batch_metrics_handler,
)

# Example function signature we are testing on
@@ -263,3 +268,92 @@ def modifies(_):

    assert x["data"] == 0
    logger.warning.assert_not_called()


def test_batch_metrics_handler_logs_all_metrics(start_run):  # pylint: disable=unused-argument
    with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:

[Review comment] Do we need to mock …
[Reply] do we want to do this for every test? or is it sufficient to do it for one test and know that it works end-to-end, then save time by keeping the mock for the rest of the tests
[Reply] actually it seems this really only applies to one of the tests.. the rest care about the intermediate state AS the BatchMetricsLogger is logging, not the final outcome, so it makes more sense to use mocks there
        with with_batch_metrics_handler() as batch_metrics_handler:
            for i in range(100):
                batch_metrics_handler.record_metrics({"x": i}, i)

        # collect the args of all the logging calls
        recorded_metrics = []
        for call in log_batch_mock.call_args_list:
            _, kwargs = call
            metrics_arr = kwargs["metrics"]
            for metric in metrics_arr:
                recorded_metrics.append({metric._key: metric._value})

        desired_metrics = [{"x": i} for i in range(100)]

        assert recorded_metrics == desired_metrics
def test_batch_metrics_handler_runs_training_and_logging_in_correct_ratio(
    start_run,
):  # pylint: disable=unused-argument
    with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock, mock.patch(
        "mlflow.utils.autologging_utils.time_wrapper_for_log"
    ) as log_time_mock, mock.patch(
        "mlflow.utils.autologging_utils.time_wrapper_for_current"
    ) as current_time_mock, mock.patch(
        "mlflow.utils.autologging_utils.time_wrapper_for_timestamp"
    ) as timestamp_time_mock:
        # training occurs every second
        current_time_mock.side_effect = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        # logging takes 1 second, numbers don't matter here
        log_time_mock.side_effect = itertools.cycle([100, 101])
        timestamp_time_mock.return_value = 9999  # this doesn't matter

[Review comment] If this number doesn't matter, do we need to mock it?
        with with_batch_metrics_handler() as batch_metrics_handler:
            batch_metrics_handler.record_metrics({"x": 1}, step=0)  # data doesn't matter
            # first metrics should be skipped to record a previous timestamp and batch log time

            log_batch_mock.assert_called_once()

            log_batch_mock.reset_mock()  # resets the 'calls' of this mock

            # the above 'training' took 1 second. So with fudge factor of 10x,

[Review comment] As above, I don't think that fudge factor is accurate terminology for this use case

            # 10 more 'training' should happen before the metrics are sent.
            for _ in range(9):
                batch_metrics_handler.record_metrics({"x": 1}, step=0)
                log_batch_mock.assert_not_called()

            batch_metrics_handler.record_metrics({"x": 1}, step=0)
            log_batch_mock.assert_called_once()
def test_batch_metrics_handler_chunks_metrics_when_batch_logging(
    start_run,
):  # pylint: disable=unused-argument
    with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock, mock.patch(
        "mlflow.utils.autologging_utils.time_wrapper_for_timestamp"
    ) as timestamp_time_mock:
        timestamp_time_mock.return_value = 0

[Review comment] Is this timestamp mock necessary? If not, it seems that we can get rid of …

        with with_batch_metrics_handler() as batch_metrics_handler:
            batch_metrics_handler.record_metrics({hex(x): x for x in range(5000)}, step=0)
            run_id = mlflow.active_run().info.run_id

        for call_idx, call in enumerate(log_batch_mock.call_args_list):
            _, kwargs = call

            assert kwargs["run_id"] == run_id
            assert len(kwargs["metrics"]) == 1000
            for metric_idx, metric in enumerate(kwargs["metrics"]):
                assert metric.key == hex(call_idx * 1000 + metric_idx)
                assert metric.timestamp == 0
                assert metric.value == call_idx * 1000 + metric_idx
                assert metric.step == 0
[Review comment] Instead of wrapping time.time() for mocking, can we just manipulate the total_log_batch_time and total_training_time properties of BatchMetricsHandler in our test cases? If we're concerned about the measurement of total_training_time and total_log_batch_time, we can always construct another test case that performs sleeps to simulate training / logging and then verifies that total_training_time / total_log_batch_time exceed expected thresholds.
[Reply] i tried to test this but for some reason at least in pytest sleep doesnt increase the system clock, even though the test is obviously taking longer (so the sleep is running)
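On the reviewer's point that the `time_wrapper_for_*` indirection may be unnecessary: `unittest.mock` can patch `time.time` directly, with `side_effect` supplying successive clock readings. A minimal sketch (the `measure` function is hypothetical, not from the PR; the patch affects every caller of `time.time` for its duration):

```python
import time
from unittest import mock

def measure():
    """Toy timed operation: returns elapsed wall-clock time."""
    start = time.time()
    # ... do some work ...
    return time.time() - start

# side_effect yields 100.0 for the first call and 101.5 for the second,
# so the measured duration is deterministic with no wrapper functions needed.
with mock.patch("time.time", side_effect=[100.0, 101.5]):
    assert measure() == 1.5
```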