Introduce time-based batch metrics logging and change XGBoost to use it (#3619)

* xgboost log on every iteration with timing

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* get avg time

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* fix

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* batch send all at the end of training

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* stash

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* rename promise to future

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* remove batch_log_interval

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* make should_purge have no side effects

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* do not assume step anymore

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* add test case

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* stash

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* autofmt

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* linting

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* some cleanup and gather batch log time on initial iteration

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* more cleanup

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* reimport time

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* revert changes to xgboost example

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* add chunking test and clean up tests

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* refactor chunking test

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* revert adding __eq__ method to metric entity

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* remove commented-out code

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* fix xgboost autolog tests

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* remove unused import

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* remove unused import

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* code review

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* fix line length

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* change to total log batch time instead of average

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* make test go through two cycles of batch logging

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* code review

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* some code review

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* code review

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* remove extra param from xgboost example

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>

* nit fix

Signed-off-by: Andrew Nitu <andrewnitu@gmail.com>
andrewnitu committed Nov 11, 2020
1 parent 21fdf64 commit 8eed7c4
Showing 3 changed files with 246 additions and 16 deletions.
88 changes: 88 additions & 0 deletions mlflow/utils/autologging_utils.py
@@ -1,9 +1,14 @@
import inspect
import functools
import warnings
import time
import contextlib

import mlflow
from mlflow.utils import gorilla
from mlflow.entities import Metric
from mlflow.tracking.client import MlflowClient
from mlflow.utils.validation import MAX_METRICS_PER_BATCH


INPUT_EXAMPLE_SAMPLE_ROWS = 5
@@ -187,3 +192,86 @@ def resolve_input_example_and_signature(
logger.warning(model_signature_user_msg)

return input_example if log_input_example else None, model_signature


class BatchMetricsLogger:
def __init__(self, run_id):
self.run_id = run_id

# data is an array of Metric objects
self.data = []
self.total_training_time = 0
self.total_log_batch_time = 0
self.previous_training_timestamp = None

def _purge(self):
self._timed_log_batch()
self.data = []

def _timed_log_batch(self):
start = time.time()
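        # Split the accumulated metrics into chunks of at most MAX_METRICS_PER_BATCH
        # per log_batch call, so each request stays within the per-batch limit.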
metrics_slices = [
self.data[i : i + MAX_METRICS_PER_BATCH]
for i in range(0, len(self.data), MAX_METRICS_PER_BATCH)
]
for metrics_slice in metrics_slices:
try_mlflow_log(MlflowClient().log_batch, run_id=self.run_id, metrics=metrics_slice)
end = time.time()
self.total_log_batch_time += end - start

def _should_purge(self):
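        # Flush only once total training time is at least 10x the total time spent
        # logging batches, keeping logging overhead to roughly 10% of training time.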
target_training_to_logging_time_ratio = 10
if (
self.total_training_time
>= self.total_log_batch_time * target_training_to_logging_time_ratio
):
return True

return False

def record_metrics(self, metrics, step):
"""
Submit a set of metrics to be logged. The metrics may not be immediately logged, as this
class will batch them in order to not increase execution time too much by logging
frequently.
:param metrics: dictionary containing key, value pairs of metrics to be logged.
:param step: the training step that the metrics correspond to.
"""
current_timestamp = time.time()
if self.previous_training_timestamp is None:
self.previous_training_timestamp = current_timestamp

training_time = current_timestamp - self.previous_training_timestamp

self.total_training_time += training_time

for key, value in metrics.items():
self.data.append(Metric(key, value, int(current_timestamp * 1000), step))

if self._should_purge():
self._purge()

self.previous_training_timestamp = current_timestamp


@contextlib.contextmanager
def batch_metrics_logger(run_id):
"""
Context manager that yields a BatchMetricsLogger object, which metrics can be logged against.
The BatchMetricsLogger keeps metrics in a list until it decides they should be logged, at
which point the accumulated metrics will be batch logged. The BatchMetricsLogger ensures
that logging imposes no more than a 10% overhead on the training, where the training is
measured by adding up the time elapsed between consecutive calls to record_metrics.
If logging a batch fails, a warning will be emitted and subsequent metrics will continue to
be collected.
Once the context is closed, any metrics that have yet to be logged will be logged.
:param run_id: ID of the run that the metrics will be logged to.
"""

batch_metrics_logger = BatchMetricsLogger(run_id)
yield batch_metrics_logger
batch_metrics_logger._purge()
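
For reference, a minimal usage sketch of the new batch_metrics_logger context manager added above (the metric name, values, and loop are illustrative; the API calls follow this diff):

import mlflow
from mlflow.utils.autologging_utils import batch_metrics_logger

with mlflow.start_run() as run:
    with batch_metrics_logger(run.info.run_id) as metrics_logger:
        for step in range(100):
            # Illustrative metric; record_metrics buffers these and flushes them
            # in batches so that logging stays within ~10% of training time.
            metrics_logger.record_metrics({"loss": 1.0 / (step + 1)}, step)
# Any metrics still buffered are flushed when the context manager exits.
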
33 changes: 17 additions & 16 deletions mlflow/xgboost.py
@@ -45,6 +45,7 @@
resolve_input_example_and_signature,
_InputExampleInfo,
ENSURE_AUTOLOGGING_ENABLED_TEXT,
batch_metrics_logger,
)
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS

@@ -337,12 +338,13 @@ def __init__(self, *args, **kwargs):
original(self, *args, **kwargs)

def train(*args, **kwargs):
def record_eval_results(eval_results):
def record_eval_results(eval_results, metrics_logger):
"""
Create a callback function that records evaluation results.
"""

def callback(env):
metrics_logger.record_metrics(dict(env.evaluation_result_list), env.iteration)
eval_results.append(dict(env.evaluation_result_list))

return callback
@@ -416,22 +418,21 @@ def log_feature_importance_plot(features, importance, importance_type):
# adding a callback that records evaluation results.
eval_results = []
callbacks_index = all_arg_names.index("callbacks")
callback = record_eval_results(eval_results)
if num_pos_args >= callbacks_index + 1:
tmp_list = list(args)
tmp_list[callbacks_index] += [callback]
args = tuple(tmp_list)
elif "callbacks" in kwargs and kwargs["callbacks"] is not None:
kwargs["callbacks"] += [callback]
else:
kwargs["callbacks"] = [callback]

# training model
model = original(*args, **kwargs)

# logging metrics on each iteration.
for idx, metrics in enumerate(eval_results):
try_mlflow_log(mlflow.log_metrics, metrics, step=idx)
run_id = mlflow.active_run().info.run_id
with batch_metrics_logger(run_id) as metrics_logger:
callback = record_eval_results(eval_results, metrics_logger)
if num_pos_args >= callbacks_index + 1:
tmp_list = list(args)
tmp_list[callbacks_index] += [callback]
args = tuple(tmp_list)
elif "callbacks" in kwargs and kwargs["callbacks"] is not None:
kwargs["callbacks"] += [callback]
else:
kwargs["callbacks"] = [callback]

# training model
model = original(*args, **kwargs)

# If early_stopping_rounds is present, logging metrics at the best iteration
# as extra metrics with the max step + 1.
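
For context, a minimal end-to-end sketch of the behavior this change enables: with autologging on, the patched xgboost.train records each iteration's evaluation metrics through the batch logger instead of issuing one log_metrics call per iteration. The dataset and parameters below are synthetic and purely illustrative; this assumes xgboost and numpy are installed.

import mlflow
import mlflow.xgboost
import numpy as np
import xgboost as xgb

mlflow.xgboost.autolog()

X = np.random.rand(200, 4)
y = np.random.randint(2, size=200)
dtrain = xgb.DMatrix(X, label=y)

with mlflow.start_run():
    # Per-iteration metrics for the "train" eval set are recorded via
    # BatchMetricsLogger and flushed in batches rather than one call per step.
    xgb.train(
        {"objective": "binary:logistic"},
        dtrain,
        num_boost_round=20,
        evals=[(dtrain, "train")],
    )
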
141 changes: 141 additions & 0 deletions tests/utils/test_autologging_utils.py
@@ -1,14 +1,19 @@
import inspect
import time
import pytest
from unittest.mock import Mock, call
from unittest import mock


import mlflow
from mlflow.utils import gorilla
from mlflow.tracking.client import MlflowClient
from mlflow.utils.autologging_utils import (
get_unspecified_default_args,
log_fn_args_as_params,
wrap_patch,
resolve_input_example_and_signature,
batch_metrics_logger,
)

# Example function signature we are testing on
@@ -263,3 +268,139 @@ def modifies(_):

assert x["data"] == 0
logger.warning.assert_not_called()


def test_batch_metrics_logger_logs_all_metrics(start_run,): # pylint: disable=unused-argument
run_id = mlflow.active_run().info.run_id
with batch_metrics_logger(run_id) as metrics_logger:
for i in range(100):
metrics_logger.record_metrics({hex(i): i}, i)

metrics_on_run = mlflow.tracking.MlflowClient().get_run(run_id).data.metrics

for i in range(100):
assert hex(i) in metrics_on_run
assert metrics_on_run[hex(i)] == i


def test_batch_metrics_logger_runs_training_and_logging_in_correct_ratio(
start_run,
): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:
run_id = mlflow.active_run().info.run_id
with batch_metrics_logger(run_id) as metrics_logger:
metrics_logger.record_metrics({"x": 1}, step=0) # data doesn't matter

# first metrics should be logged immediately to record a previous timestamp and
# batch log time
log_batch_mock.assert_called_once()

metrics_logger.total_log_batch_time = 1
metrics_logger.total_training_time = 1

log_batch_mock.reset_mock() # resets the 'calls' of this mock

# the above 'training' took 1 second. So with target training-to-logging time ratio of
# 10:1, 9 more 'training' should happen without sending the batch and then after the
# 10th training the batch should be sent.
for i in range(2, 11):
metrics_logger.record_metrics({"x": 1}, step=0)
log_batch_mock.assert_not_called()
metrics_logger.total_training_time = i

# at this point, average log batch time is 1, and total training time is 9
# thus the next record_metrics call should send the batch.
metrics_logger.record_metrics({"x": 1}, step=0)
log_batch_mock.assert_called_once()

# update log_batch time to reflect the 'mocked' training time
metrics_logger.total_log_batch_time = 2

log_batch_mock.reset_mock() # reset the recorded calls

for i in range(12, 21):
metrics_logger.record_metrics({"x": 1}, step=0)
log_batch_mock.assert_not_called()
metrics_logger.total_training_time = i

metrics_logger.record_metrics({"x": 1}, step=0)
log_batch_mock.assert_called_once()


def test_batch_metrics_logger_chunks_metrics_when_batch_logging(
start_run,
): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:
run_id = mlflow.active_run().info.run_id
with batch_metrics_logger(run_id) as metrics_logger:
metrics_logger.record_metrics({hex(x): x for x in range(5000)}, step=0)
run_id = mlflow.active_run().info.run_id

for call_idx, call in enumerate(log_batch_mock.call_args_list):
_, kwargs = call

assert kwargs["run_id"] == run_id
assert len(kwargs["metrics"]) == 1000
for metric_idx, metric in enumerate(kwargs["metrics"]):
assert metric.key == hex(call_idx * 1000 + metric_idx)
assert metric.value == call_idx * 1000 + metric_idx
assert metric.step == 0


def test_batch_metrics_logger_records_time_correctly(start_run,): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch", wraps=lambda *args, **kwargs: time.sleep(1)):
run_id = mlflow.active_run().info.run_id
with batch_metrics_logger(run_id) as metrics_logger:
metrics_logger.record_metrics({"x": 1}, step=0)

assert metrics_logger.total_log_batch_time >= 1

time.sleep(2)

metrics_logger.record_metrics({"x": 1}, step=0)

assert metrics_logger.total_training_time >= 2


def test_batch_metrics_logger_logs_timestamps_as_int_milliseconds(
start_run,
): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock, mock.patch(
"time.time", return_value=123.45678901234567890
):
run_id = mlflow.active_run().info.run_id
with batch_metrics_logger(run_id) as metrics_logger:
metrics_logger.record_metrics({"x": 1}, step=0)

_, kwargs = log_batch_mock.call_args

logged_metric = kwargs["metrics"][0]

assert logged_metric.timestamp == 123456


def test_batch_metrics_logger_continues_if_log_batch_fails(
start_run,
): # pylint: disable=unused-argument
with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock:
log_batch_mock.side_effect = [Exception("asdf"), None]

run_id = mlflow.active_run().info.run_id
with batch_metrics_logger(run_id) as metrics_logger:
# this call should fail to record since log_batch raised exception
metrics_logger.record_metrics({"x": 1}, step=0)

metrics_logger.record_metrics({"y": 2}, step=1)

# even though the first call to log_batch failed, the BatchMetricsLogger should continue
# logging subsequent batches
last_call = log_batch_mock.call_args_list[-1]

_, kwargs = last_call

assert kwargs["run_id"] == run_id
assert len(kwargs["metrics"]) == 1
metric = kwargs["metrics"][0]
assert metric.key == "y"
assert metric.value == 2
assert metric.step == 1
