Skip to content

Commit

Permalink
Support async logging per X seconds (#12324)
Browse files Browse the repository at this point in the history
Signed-off-by: chenmoneygithub <chen.qian@databricks.com>
  • Loading branch information
chenmoneygithub committed Jun 12, 2024
1 parent 91cde6e commit c5ae48e
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 19 deletions.
6 changes: 6 additions & 0 deletions mlflow/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,9 @@ def get(self):
#: variable sets the `pool_maxsize` parameter in the `requests.adapters.HTTPAdapter` constructor.
#: By adjusting this variable, users can enhance the concurrency of HTTP requests made by MLflow.
MLFLOW_HTTP_POOL_MAXSIZE = _EnvironmentVariable("MLFLOW_HTTP_POOL_MAXSIZE", int, 10)

#: Specifies the length of time, in seconds, that the asynchronous logging thread waits before
#: logging a batch, so that data queued during the wait can be fetched and merged together.
#: The default is ``None``; when the value is unset or ``0`` the logging thread does not apply a
#: buffering wait and takes queued entries one at a time.
MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS = _EnvironmentVariable(
    "MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS", int, None
)
9 changes: 9 additions & 0 deletions mlflow/store/tracking/abstract_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,15 @@ def log_batch_async(self, run_id, metrics, params, tags) -> RunOperations:
run_id=run_id, metrics=metrics, params=params, tags=tags
)

def end_async_logging(self):
    """
    Stops the async logging queue from accepting new data.

    Does nothing when the queue is not active. Unlike a flush, which ensures all data is
    processed before returning, this only ends the queue.
    """
    if not self._async_logging_queue.is_active():
        # Nothing to end: the queue was never activated or has already been ended.
        return
    self._async_logging_queue.end_async_logging()

def flush_async_logging(self):
"""
Flushes the async logging queue. This method is a no-op if the queue is not active.
Expand Down
4 changes: 4 additions & 0 deletions mlflow/tracking/_tracking_service/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,10 @@ def set_terminated(self, run_id, status=None, end_time=None):
"""
end_time = end_time if end_time else get_current_time_millis()
status = status if status else RunStatus.to_string(RunStatus.FINISHED)
# Tell the store to stop async logging: stop accepting new data and log already-enqueued
# data in the background. This call ensures that all async logging data has been
# submitted for logging, but not necessarily that logging has finished.
self.store.end_async_logging()
self.store.update_run_info(
run_id,
run_status=RunStatus.from_string(status),
Expand Down
71 changes: 63 additions & 8 deletions mlflow/utils/async_logging/async_logging_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from typing import List

from mlflow.entities.metric import Metric
from mlflow.entities.param import Param
from mlflow.entities.run_tag import RunTag
from mlflow.environment_variables import MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE
from mlflow.environment_variables import (
MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS,
MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE,
)
from mlflow.utils.async_logging.run_batch import RunBatch
from mlflow.utils.async_logging.run_operations import RunOperations

Expand Down Expand Up @@ -56,15 +60,20 @@ def _at_exit_callback(self) -> None:
except Exception as e:
_logger.error(f"Encountered error while trying to finish logging: {e}")

def end_async_logging(self) -> None:
    """Signal the batch logging thread to stop, wait for it to exit, and mark the queue
    as no longer activated.

    The stop event is set before joining, so the join returns once the logging thread has
    observed the event and finished its loop.
    """
    with self._lock:
        # Stop the data processing thread.
        self._stop_data_logging_thread_event.set()
        # Waits till logging queue is drained.
        self._batch_logging_thread.join()
        # NOTE(review): the thread is joined while `self._lock` is held; confirm the logging
        # thread never needs this lock, otherwise the join could block indefinitely.
        self._is_activated = False

def flush(self) -> None:
"""Flush the async logging queue.
Calling this method will flush the queue to ensure all the data are logged.
"""
# Stop the data processing thread.
self._stop_data_logging_thread_event.set()
# Waits till logging queue is drained.
self._batch_logging_thread.join()
self.end_async_logging()
self._batch_logging_worker_threadpool.shutdown(wait=True)
self._batch_status_check_threadpool.shutdown(wait=True)

Expand All @@ -88,6 +97,40 @@ def _logging_loop(self) -> None:

raise MlflowException(f"Exception inside the run data logging thread: {e}")

def _fetch_batch_from_queue(self) -> List[RunBatch]:
    """Drains the queue, merging compatible entries into as few batches as possible.

    Entries are taken in queue order. An entry is merged into the current batch when it has
    the same ``run_id`` and the combined batch stays below every size cap; otherwise the
    current batch is closed and the entry starts a new one. Merged entries are registered
    through ``add_child_batch`` so they stay reachable from the batch they were folded into.

    Returns:
        A list of ``RunBatch`` objects to log; empty when the queue holds no data.
    """
    # Caps on a merged batch, mirroring the documented log-batch request limits. The total
    # cap matters because a batch can be under each per-kind cap and still be too large.
    max_metrics = 1000
    max_params = 100
    max_tags = 100
    max_total_items = 1000

    batches = []
    queue_size = self._queue.qsize()  # Estimate the queue's size.
    try:
        # Non-blocking: a queue that is empty must never stall the logging thread.
        merged_batch = self._queue.get_nowait()
    except Empty:
        return batches

    for _ in range(queue_size - 1):
        try:
            batch = self._queue.get_nowait()
        except Empty:
            # `queue_size` is an estimate, so the queue may run dry early.
            break

        num_metrics = len(merged_batch.metrics) + len(batch.metrics)
        num_params = len(merged_batch.params) + len(batch.params)
        num_tags = len(merged_batch.tags) + len(batch.tags)
        if (
            merged_batch.run_id != batch.run_id
            or num_metrics >= max_metrics
            or num_params >= max_params
            or num_tags >= max_tags
            or num_metrics + num_params + num_tags >= max_total_items
        ):
            # Make a new batch if the run_id is different or the batch is full.
            batches.append(merged_batch)
            merged_batch = batch
        else:
            merged_batch.add_child_batch(batch)
            merged_batch.params.extend(batch.params)
            merged_batch.tags.extend(batch.tags)
            merged_batch.metrics.extend(batch.metrics)

    batches.append(merged_batch)
    return batches

def _log_run_data(self) -> None:
"""Process the run data in the running runs queues.
Expand All @@ -100,9 +143,13 @@ def _log_run_data(self) -> None:
Returns: None
"""
run_batch = None # type: RunBatch
async_logging_buffer_seconds = MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS.get()
try:
run_batch = self._queue.get(timeout=1)
if async_logging_buffer_seconds:
self._stop_data_logging_thread_event.wait(async_logging_buffer_seconds)
run_batches = self._fetch_batch_from_queue()
else:
run_batches = [self._queue.get(timeout=1)]
except Empty:
# Ignore empty queue exception
return
Expand All @@ -118,13 +165,20 @@ def logging_func(run_batch):

# Signal the batch processing is done.
run_batch.completion_event.set()
for child_batch in run_batch.child_batches:
# Signal the child batch processing is done.
child_batch.completion_event.set()

except Exception as e:
_logger.error(f"Run Id {run_batch.run_id}: Failed to log run data: Exception: {e}")
run_batch.exception = e
run_batch.completion_event.set()
for child_batch in run_batch.child_batches:
# Signal the child batch processing is done.
child_batch.completion_event.set()

self._batch_logging_worker_threadpool.submit(logging_func, run_batch)
for run_batch in run_batches:
self._batch_logging_worker_threadpool.submit(logging_func, run_batch)

def _wait_for_batch(self, batch: RunBatch) -> None:
"""Wait for the given batch to be processed by the logging thread.
Expand Down Expand Up @@ -242,6 +296,7 @@ def _set_up_logging_thread(self) -> None:
max_workers=MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.get() or 10,
thread_name_prefix="MLflowAsyncLoggingStatusCheck",
)

self._batch_logging_thread.start()

def activate(self) -> None:
Expand Down
29 changes: 19 additions & 10 deletions mlflow/utils/async_logging/run_batch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import threading
from typing import List
from typing import List, Optional

from mlflow.entities.metric import Metric
from mlflow.entities.param import Param
Expand All @@ -10,26 +10,27 @@ class RunBatch:
def __init__(
    self,
    run_id: str,
    params: Optional[List["Param"]] = None,
    tags: Optional[List["RunTag"]] = None,
    metrics: Optional[List["Metric"]] = None,
    completion_event: Optional[threading.Event] = None,
) -> None:
    """Initializes an instance of `RunBatch`.

    Args:
        run_id: The ID of the run.
        params: A list of parameters. Default is None, which is stored as an empty list.
        tags: A list of tags. Default is None, which is stored as an empty list.
        metrics: A list of metrics. Default is None, which is stored as an empty list.
        completion_event: A threading.Event object. Default is None.
    """
    self.run_id = run_id
    # Store copies rather than the caller's lists: batches are merged by extending these
    # lists in place, which would otherwise modify (or, if the same list object backs two
    # batches, duplicate) the data the caller passed in.
    self.params = list(params) if params else []
    self.tags = list(tags) if tags else []
    self.metrics = list(metrics) if metrics else []
    self.completion_event = completion_event
    self._exception = None
    # Batches that have been merged into this one; see `add_child_batch`.
    self.child_batches = []

@property
def exception(self):
Expand All @@ -39,3 +40,11 @@ def exception(self):
@exception.setter
def exception(self, exception):
    """Record the exception raised while logging this batch.

    The same exception is recorded on every batch in ``child_batches``. A merged batch is
    logged as a single unit, so a failure applies equally to the batches folded into it;
    without this, a child batch would be signalled as complete with no exception set.
    """
    self._exception = exception
    for child_batch in self.child_batches:
        child_batch.exception = exception

def add_child_batch(self, child_batch):
    """Register `child_batch` as having been merged into this batch.

    The child is retained so that, once this batch has been processed, the child can be
    notified as well.
    """
    children = self.child_batches
    children.append(child_batch)
32 changes: 31 additions & 1 deletion tests/utils/test_async_logging_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import threading
import time
import uuid
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -42,7 +43,36 @@ def consume_queue_data(self, run_id, metrics, tags, params):
self.received_tags.extend(tags or [])


def test_single_thread_publish_consume_queue():
def test_single_thread_publish_consume_queue(monkeypatch):
    """Check that queued entries are grouped before being submitted to the worker pool.

    Buffering is enabled through the environment variable, both thread pools are replaced
    with mocks, and the number of `submit` calls on the worker pool is asserted after a flush.
    """
    # Turn on time-window buffering (3 seconds) for the logging thread.
    monkeypatch.setenv("MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS", "3")

    with patch.object(
        AsyncLoggingQueue, "_batch_logging_worker_threadpool", create=True
    ) as mock_worker_threadpool, patch.object(
        AsyncLoggingQueue, "_batch_status_check_threadpool", create=True
    ) as mock_check_threadpool:
        # Replace submit/shutdown with mocks so calls are recorded instead of executed.
        mock_worker_threadpool.submit = MagicMock()
        mock_check_threadpool.submit = MagicMock()
        mock_worker_threadpool.shutdown = MagicMock()
        mock_check_threadpool.shutdown = MagicMock()

        run_id = "test_run_id"
        run_data = RunData()
        async_logging_queue = AsyncLoggingQueue(run_data.consume_queue_data)
        async_logging_queue.activate()
        # Assigned after `activate()` so the instance uses the mocked pools from here on.
        async_logging_queue._batch_logging_worker_threadpool = mock_worker_threadpool
        async_logging_queue._batch_status_check_threadpool = mock_check_threadpool

        for params, tags, metrics in _get_run_data():
            async_logging_queue.log_batch_async(
                run_id=run_id, metrics=metrics, tags=tags, params=params
            )
        async_logging_queue.flush()
        # 2 batches are sent to the worker thread pool due to grouping, otherwise it would be 5.
        assert mock_worker_threadpool.submit.call_count == 2


def test_grouping_batch_in_time_window():
run_id = "test_run_id"
run_data = RunData()
async_logging_queue = AsyncLoggingQueue(run_data.consume_queue_data)
Expand Down

0 comments on commit c5ae48e

Please sign in to comment.