Introduce time-based batch metrics logging and change XGBoost to use it #3619

Merged

Commits (33)

35b624b
xgboost log on every iteration with timing
andrewnitu Oct 27, 2020
d8fff96
get avg time
andrewnitu Oct 27, 2020
a3fef38
fix
andrewnitu Oct 27, 2020
1eb5da3
batch send all at the end of training
andrewnitu Oct 27, 2020
7eb0d4c
stash
andrewnitu Oct 29, 2020
9b1a782
rename promise to future
andrewnitu Oct 29, 2020
c937852
remove batch_log_interval
andrewnitu Oct 29, 2020
cf22cb8
make should_purge have no side effects
andrewnitu Oct 29, 2020
13ee723
do not assume step anymore
andrewnitu Oct 31, 2020
d74d638
add test case
andrewnitu Oct 31, 2020
05aac0e
stash
andrewnitu Nov 2, 2020
d99f981
autofmt
andrewnitu Nov 2, 2020
8c59ab8
linting
andrewnitu Nov 2, 2020
2e9cbdf
some cleanup and gather batch log time on initial iteration
andrewnitu Nov 2, 2020
9d15742
more cleanup
andrewnitu Nov 2, 2020
ee7614e
reimport time
andrewnitu Nov 2, 2020
055971e
revert changes to xgboost example
andrewnitu Nov 2, 2020
be1db98
add chunking test and clean up tests
andrewnitu Nov 2, 2020
c5a11ae
refactor chunking test
andrewnitu Nov 2, 2020
a36dcf5
revert adding __eq__ method to metric entity
andrewnitu Nov 2, 2020
6aab196
remove commented-out code
andrewnitu Nov 2, 2020
e36549e
fix xgboost autolog tests
andrewnitu Nov 2, 2020
0ec68db
remove unused import
andrewnitu Nov 2, 2020
1cb6057
remove unused import
andrewnitu Nov 3, 2020
b106659
code review
andrewnitu Nov 4, 2020
fd61abd
fix line length
andrewnitu Nov 4, 2020
1304367
change to total log batch time instead of average
andrewnitu Nov 4, 2020
c7b41ce
make test go through two cycles of batch logging
andrewnitu Nov 5, 2020
3c18c8b
code review
andrewnitu Nov 5, 2020
1d0c8aa
some code review
andrewnitu Nov 6, 2020
f665e4f
code review
andrewnitu Nov 9, 2020
bd71788
remove extra param from xgboost example
andrewnitu Nov 10, 2020
26dc835
nit fix
andrewnitu Nov 11, 2020
100 changes: 100 additions & 0 deletions mlflow/utils/autologging_utils.py
@@ -1,9 +1,14 @@
import inspect
import functools
import warnings
import time
import contextlib

import mlflow
from mlflow.utils import gorilla
from mlflow.entities import Metric
from mlflow.tracking.client import MlflowClient
from mlflow.utils.validation import MAX_METRICS_PER_BATCH


INPUT_EXAMPLE_SAMPLE_ROWS = 5
@@ -187,3 +192,98 @@ def resolve_input_example_and_signature(
logger.warning(model_signature_user_msg)

return input_example if log_input_example else None, model_signature


# wrapper functions so that time.time() can be mocked easily in the tests
def time_wrapper_for_log():
return time.time()


def time_wrapper_for_current():
return time.time()


def time_wrapper_for_timestamp():

Collaborator:

Instead of wrapping time.time() for mocking, can we just manipulate the total_log_batch_time and total_training_time properties of BatchMetricsHandler in our test cases?

If we're concerned about the measurement of total_training_time and total_log_batch_time, we can always construct another test case that performs sleeps to simulate training / logging and then verifies that total_training_time / total_log_batch_time exceed expected thresholds.

Collaborator Author:

I tried to test this, but for some reason (at least in pytest) sleep doesn't increase the system clock, even though the test is obviously taking longer, so the sleep is running.

return time.time()
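
To make the alternative discussed above concrete, here is a minimal sketch of manipulating the timing fields directly in a test instead of mocking a clock; the class and field names come from this diff, and the import path is an assumption:

from mlflow.utils.autologging_utils import BatchMetricsHandler  # path assumed from this diff

# Pretend 10 seconds of training have elapsed against a 1-second average
# log_batch cost; with the 10x threshold in _should_purge, this triggers a purge.
handler = BatchMetricsHandler()
handler.total_training_time = 10.0
handler.total_log_batch_time = 1.0
handler.num_log_batch = 1
assert handler._should_purge()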


# we pass the batch_metrics_handler through so that the callback can access it
def _timed_log_batch(batch_metrics_handler, run_id, metrics):

Collaborator:

Can we move this into BatchMetricsHandler? Seems to make sense given that the method refers to an instance of batch_metrics_handler.

start = time_wrapper_for_log()
metrics_slices = [
metrics[i * MAX_METRICS_PER_BATCH : (i + 1) * MAX_METRICS_PER_BATCH]
for i in range((len(metrics) + MAX_METRICS_PER_BATCH - 1) // MAX_METRICS_PER_BATCH)
]

Collaborator:

I think using the step parameter for range() will simplify things here, e.g.:

def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

(Credits to https://stackoverflow.com/a/312464/11952869)

for metrics_slice in metrics_slices:
MlflowClient().log_batch(run_id=run_id, metrics=metrics_slice)
end = time_wrapper_for_log()
batch_metrics_handler.total_log_batch_time += end - start
batch_metrics_handler.num_log_batch += 1
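
A minimal sketch of how the two suggestions above might combine, moving the timed logging onto the class and chunking with range()'s step parameter; this is illustrative only, not necessarily the shape the PR ends up with:

import time

from mlflow.tracking.client import MlflowClient
from mlflow.utils.validation import MAX_METRICS_PER_BATCH


class BatchMetricsLogger:  # reviewer-suggested name
    # ... __init__ and the other methods as in this diff ...

    def _timed_log_batch(self, run_id, metrics):
        # Time the whole chunked upload so the training-to-logging ratio stays accurate.
        start = time.time()
        for i in range(0, len(metrics), MAX_METRICS_PER_BATCH):
            chunk = metrics[i : i + MAX_METRICS_PER_BATCH]
            MlflowClient().log_batch(run_id=run_id, metrics=chunk)
        self.total_log_batch_time += time.time() - start
        self.num_log_batch += 1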


class BatchMetricsHandler: # BatchMetricsLogger maybe?

Collaborator:

+1. Let's call this BatchMetricsLogger.

def __init__(self):
# data is an array of tuples of the form (timestamp, metrics at timestamp)
self.data = {}

Collaborator:

Seems like a dictionary to me! (Though I think it should be a list of Metric objects - see comment below)

self.total_training_time = 0
self.total_log_batch_time = 0
self.num_log_batch = 0
self.previous_training_timestamp = None

def _purge(self):
run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id

Collaborator:

run_id should be a parameter of BatchMetricsHandler or of record_metrics. I can't think of a case where we'd want BatchMetricsHandler to implicitly start a new run.

For simplicity, I think it makes sense to tie BatchMetricsLogger to a single run_id via a constructor parameter. If you can think of any existing or near-term use cases where adding it to record_metrics would be necessary, please let me know!

Collaborator Author:

Sure, that makes sense. I just needed the run_id, and that was my first thought for grabbing it. :D
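
A minimal sketch of what the constructor might look like with run_id pinned at construction time, as suggested above (illustrative only):

class BatchMetricsLogger:
    def __init__(self, run_id):
        # Tie the logger to one run so _purge never has to look a run up
        # (or implicitly start one) through the fluent API.
        self.run_id = run_id
        self.data = []  # see the later thread about storing Metric objects directly
        self.total_training_time = 0
        self.total_log_batch_time = 0
        self.num_log_batch = 0
        self.previous_training_timestamp = None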

final_metrics = []

for step, metrics_at_step in self.data.items():
for entry in metrics_at_step:
timestamp = entry[0]
metrics_at_timestamp = entry[1]

for key, value in metrics_at_timestamp.items():
final_metrics.append(Metric(key, value, timestamp, step))

_timed_log_batch(self, run_id=run_id, metrics=final_metrics)

self.data = {}

def _should_purge(self):
if self.num_log_batch == 0:
return True

# we give some extra time in case of network slowdown
log_batch_time_fudge_factor = 10

Collaborator:

Fudge factor seems like the wrong term here. This is the desired ratio of training time to batch logging time. Perhaps target_training_to_logging_time_ratio?

if (
self.total_training_time
>= self.total_log_batch_time / self.num_log_batch * log_batch_time_fudge_factor
):
return True

return False
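
With the suggested rename, the check above might read roughly as follows (a sketch, not the merged code):

def _should_purge(self):
    target_training_to_logging_time_ratio = 10  # reviewer-suggested name
    if self.num_log_batch == 0:
        return True
    average_log_batch_time = self.total_log_batch_time / self.num_log_batch
    return (
        self.total_training_time
        >= average_log_batch_time * target_training_to_logging_time_ratio
    )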

# metrics is a dict representing the set of metrics collected during one iteration

Collaborator:

Can we convert this comment into a docstring?
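
For example, the comment could become a docstring along these lines (wording is only a suggestion):

def record_metrics(self, metrics, step):
    """
    Submit a set of metrics to be logged in a batch at some later point.

    :param metrics: dict of metric key/value pairs collected during one iteration.
    :param step: the training step that these metrics correspond to.
    """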

def record_metrics(self, metrics, step):
current_timestamp = time_wrapper_for_current()
if self.previous_training_timestamp is None:
self.previous_training_timestamp = current_timestamp

training_time = current_timestamp - self.previous_training_timestamp

self.total_training_time += training_time

if step in self.data:
self.data[step].append([int(time_wrapper_for_timestamp() * 1000), metrics])

Collaborator:

Can time_wrapper_for_timestamp() and these other timing functions give us times in millis so we don't have to convert them?

Collaborator (@dbczumar, Nov 3, 2020):

On second thought, I'm not sure we need these timer wrappers. See #3619 (comment)

else:
self.data[step] = [[int(time_wrapper_for_timestamp() * 1000), metrics]]

Collaborator:

Can we construct Metric objects here and just append metric objects to a list, rather than keeping track of things by step? Seems like we ultimately collapse everything into a list at purge time anyway. If we want to maintain a sorted order based on step, timestamp, etc, we can use the sorted function within the purge routine.

Collaborator Author:

Yeah, I don't see why not. It seems kind of wasteful to group them by step and then ungroup them again, as I'm doing now.


if self._should_purge():
self.total_training_time = 0
self._purge()

self.previous_training_timestamp = current_timestamp
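
As a rough sketch of the simplification discussed above, record_metrics could build Metric entities up front and keep them in a flat list (this assumes self.data is a list and reuses the time and Metric imports at the top of this module):

def record_metrics(self, metrics, step):
    current_timestamp = time.time()
    if self.previous_training_timestamp is None:
        self.previous_training_timestamp = current_timestamp
    self.total_training_time += current_timestamp - self.previous_training_timestamp

    timestamp_ms = int(time.time() * 1000)  # Metric timestamps are in milliseconds
    self.data.extend(
        Metric(key, value, timestamp_ms, step) for key, value in metrics.items()
    )

    if self._should_purge():
        self.total_training_time = 0
        self._purge()

    self.previous_training_timestamp = current_timestamp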


@contextlib.contextmanager
def with_batch_metrics_handler():

Collaborator:

Can we add docs here?
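
One possible wording for such a docstring, to sit directly under the def above (a suggestion only):

    """
    Context manager that yields a BatchMetricsHandler for recording metrics
    during training and flushes anything still buffered with a final _purge()
    when the block exits.
    """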

batch_metrics_handler = BatchMetricsHandler()
yield batch_metrics_handler
batch_metrics_handler._purge()
14 changes: 8 additions & 6 deletions mlflow/xgboost.py
@@ -45,6 +45,7 @@
resolve_input_example_and_signature,
_InputExampleInfo,
ENSURE_AUTOLOGGING_ENABLED_TEXT,
with_batch_metrics_handler,
)
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS

@@ -340,6 +341,9 @@ def record_eval_results(eval_results):
"""

def callback(env):
batch_metrics_handler.record_metrics(

Collaborator:

Can we add batch_metrics_handler as an argument to record_eval_results and thread it through to this callback to ensure that we're not accidentally referencing some state left over from a previous BatchMetricsHandler, for example?

dict(env.evaluation_result_list), env.iteration
)
eval_results.append(dict(env.evaluation_result_list))

return callback
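
A minimal sketch of threading the handler through explicitly, as suggested in the thread above (names follow this diff; not necessarily the merged shape):

def record_eval_results(eval_results, batch_metrics_handler):
    """Create a callback that records evaluation results via the given handler."""

    def callback(env):
        batch_metrics_handler.record_metrics(
            dict(env.evaluation_result_list), env.iteration
        )
        eval_results.append(dict(env.evaluation_result_list))

    return callback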
@@ -423,12 +427,10 @@ def log_feature_importance_plot(features, importance, importance_type):
else:
kwargs["callbacks"] = [callback]

# training model
model = original(*args, **kwargs)

# logging metrics on each iteration.
for idx, metrics in enumerate(eval_results):
try_mlflow_log(mlflow.log_metrics, metrics, step=idx)
# logging metrics on each iteration
with with_batch_metrics_handler() as batch_metrics_handler:
# training model
model = original(*args, **kwargs)

# If early_stopping_rounds is present, logging metrics at the best iteration
# as extra metrics with the max step + 1.
94 changes: 94 additions & 0 deletions tests/utils/test_autologging_utils.py
@@ -1,14 +1,19 @@
import inspect
import pytest
import itertools
from unittest.mock import Mock, call
from unittest import mock


import mlflow
from mlflow.utils import gorilla
from mlflow.tracking.client import MlflowClient
from mlflow.utils.autologging_utils import (
get_unspecified_default_args,
log_fn_args_as_params,
wrap_patch,
resolve_input_example_and_signature,
with_batch_metrics_handler,
)

# Example function signature we are testing on
@@ -263,3 +268,92 @@ def modifies(_):

assert x["data"] == 0
logger.warning.assert_not_called()


def test_batch_metrics_handler_logs_all_metrics(start_run): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:

Collaborator:

Do we need to mock log_batch? Ideally, it would be nice to test that the metrics are actually logged by leaving this unmocked and querying the run data after the iterative calls to record_metrics() complete.

Collaborator Author:

Do we want to do this for every test? Or is it sufficient to do it for one test, so we know it works end to end, and save time by keeping the mock for the rest of the tests?

Collaborator Author:

Actually, it seems this really only applies to one of the tests. The rest care about the intermediate state while the BatchMetricsLogger is logging, not about the final outcome, so it makes more sense to use mocks there.

with with_batch_metrics_handler() as batch_metrics_handler:
for i in range(100):
batch_metrics_handler.record_metrics({"x": i}, i)

# collect the args of all the logging calls
recorded_metrics = []
for call in log_batch_mock.call_args_list:
_, kwargs = call
metrics_arr = kwargs["metrics"]
for metric in metrics_arr:
recorded_metrics.append({metric._key: metric._value})

desired_metrics = [{"x": i} for i in range(100)]

assert recorded_metrics == desired_metrics
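
Following the thread above, an end-to-end variant without the log_batch mock might look roughly like this; a sketch that assumes a working local tracking setup and reuses the imports already at the top of this test file:

def test_batch_metrics_handler_logs_to_run_unmocked():
    with mlflow.start_run() as run:
        with with_batch_metrics_handler() as batch_metrics_handler:
            for i in range(10):
                batch_metrics_handler.record_metrics({"x": i}, step=i)

    # Query the run after the fact instead of inspecting mock call args.
    history = MlflowClient().get_metric_history(run.info.run_id, "x")
    assert sorted(m.value for m in history) == list(range(10))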


def test_batch_metrics_handler_runs_training_and_logging_in_correct_ratio(
start_run,
): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock, mock.patch(
"mlflow.utils.autologging_utils.time_wrapper_for_log"
) as log_time_mock, mock.patch(
"mlflow.utils.autologging_utils.time_wrapper_for_current"
) as current_time_mock, mock.patch(
"mlflow.utils.autologging_utils.time_wrapper_for_timestamp"
) as timestamp_time_mock:
current_time_mock.side_effect = [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
] # training occurs every second
log_time_mock.side_effect = itertools.cycle(
[100, 101]
) # logging takes 1 second, numbers don't matter here
timestamp_time_mock.return_value = 9999 # this doesn't matter

Collaborator:

If this number doesn't matter, do we need to mock it?


with with_batch_metrics_handler() as batch_metrics_handler:
batch_metrics_handler.record_metrics({"x": 1}, step=0) # data doesn't matter
# first metrics should be skipped to record a previous timestamp and batch log time

Collaborator:

skipped seems to imply that we're dropping metrics or not logging them; I think we mean that we're logging them immediately (i.e. "skipping waiting")

log_batch_mock.assert_called_once()

log_batch_mock.reset_mock() # resets the 'calls' of this mock

# the above 'training' took 1 second. So with fudge factor of 10x,

Collaborator:

As above, I don't think that fudge factor is accurate terminology for this use case

# 10 more 'training' should happen before the metrics are sent.
for _ in range(9):
batch_metrics_handler.record_metrics({"x": 1}, step=0)
log_batch_mock.assert_not_called()

batch_metrics_handler.record_metrics({"x": 1}, step=0)
log_batch_mock.assert_called_once()


def test_batch_metrics_handler_chunks_metrics_when_batch_logging(
start_run,
): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock, mock.patch(
"mlflow.utils.autologging_utils.time_wrapper_for_timestamp"
) as timestamp_time_mock:
timestamp_time_mock.return_value = 0

Collaborator:

Is this timestamp mock necessary? If not, it seems that we can get rid of time_wrapper_for_timestamp and inline time.time() instead.


with with_batch_metrics_handler() as batch_metrics_handler:
batch_metrics_handler.record_metrics({hex(x): x for x in range(5000)}, step=0)
run_id = mlflow.active_run().info.run_id

for call_idx, call in enumerate(log_batch_mock.call_args_list):
_, kwargs = call

assert kwargs["run_id"] == run_id
assert len(kwargs["metrics"]) == 1000
for metric_idx, metric in enumerate(kwargs["metrics"]):
assert metric.key == hex(call_idx * 1000 + metric_idx)
assert metric.timestamp == 0
assert metric.value == call_idx * 1000 + metric_idx
assert metric.step == 0