Skip to content

Commit

Permalink
Add safety guard for tracing logic to avoid exception (#12127)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
  • Loading branch information
B-Step62 committed May 25, 2024
1 parent 8b11d09 commit 24872d1
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 73 deletions.
10 changes: 10 additions & 0 deletions mlflow/entities/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ def to_immutable_span(self) -> "Span":
return Span(self._span)


NO_OP_SPAN_REQUEST_ID = "MLFLOW_NO_OP_SPAN_REQUEST_ID"


class NoOpSpan(Span):
"""
No-op implementation of the Span interface.
Expand All @@ -383,6 +386,13 @@ class NoOpSpan(Span):
def __init__(self, *args, **kwargs):
    # Accept and ignore any constructor arguments so NoOpSpan can be
    # substituted for any real span; keep only an empty attribute store.
    self._attributes = {}

@property
def request_id(self):
    """Sentinel request ID marking this span as a no-op.

    Real spans carry a backend-issued request ID; returning the
    ``NO_OP_SPAN_REQUEST_ID`` constant instead lets downstream tracing
    code recognize and short-circuit spans created after a failure.
    """
    return NO_OP_SPAN_REQUEST_ID

@property
def span_id(self):
    """A no-op span carries no span identifier; always ``None``."""
    return None
Expand Down
15 changes: 13 additions & 2 deletions mlflow/tracing/fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ def start_span(
InMemoryTraceManager.get_instance().register_span(mlflow_span)

except Exception as e:
_logger.debug("Failed to start span: %s", e, exc_info=True)
# BUG FIX: the original ended the f-string argument with a comma, so the
# second sentence was passed as a positional %-format argument to
# Logger.warning. The message contains no % placeholders, so logging
# reports a string-formatting error and the hint never appears. Use
# implicit string concatenation, matching the other warning sites in
# this change (end-span handlers in fluent.py and client.py).
_logger.warning(
    f"Failed to start span: {e}. "
    "For full traceback, set logging level to debug.",
    exc_info=_logger.isEnabledFor(logging.DEBUG),
)
mlflow_span = NoOpSpan()
yield mlflow_span
return
Expand All @@ -214,7 +218,14 @@ def start_span(
with trace_api.use_span(mlflow_span._span, end_on_exit=False):
yield mlflow_span
finally:
mlflow_span.end()
try:
mlflow_span.end()
except Exception as e:
_logger.warning(
f"Failed to end span {mlflow_span.span_id}: {e}. "
"For full traceback, set logging level to debug.",
exc_info=_logger.isEnabledFor(logging.DEBUG),
)


@experimental
Expand Down
41 changes: 10 additions & 31 deletions mlflow/tracing/processor/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from mlflow.tracing.trace_manager import InMemoryTraceManager, _Trace
from mlflow.tracing.utils import (
deduplicate_span_names_in_place,
encode_trace_id,
get_otel_attribute,
maybe_get_dependencies_schemas,
maybe_get_request_id,
Expand Down Expand Up @@ -107,37 +106,17 @@ def _start_trace(self, span: OTelSpan) -> TraceInfo:
tags.update({TraceTagKey.EVAL_REQUEST_ID: request_id})
if depedencies_schema := maybe_get_dependencies_schemas():
tags.update(depedencies_schema)
try:
trace_info = self._client._start_tracked_trace(
experiment_id=experiment_id,
# TODO: This timestamp is not accurate because it is not adjusted to exclude the
# latency of the backend API call. We do this adjustment for span start time
# above, but can't do it for trace start time until the backend API supports
# updating the trace start time.
timestamp_ms=span.start_time // 1_000_000, # nanosecond to millisecond
request_metadata=metadata,
tags=tags,
)

# TODO: This catches all exceptions from the tracking server so the in-memory tracing
# still works if the backend APIs are not ready. Once backend is ready, we should
# catch more specific exceptions and handle them accordingly.
except Exception:
_logger.debug(
"Failed to start a trace in the tracking server. This may be because the "
"backend APIs are not available. Fallback to client-side generation",
exc_info=True,
)
request_id = encode_trace_id(span.context.trace_id)
trace_info = self._create_trace_info(
request_id,
span,
experiment_id,
metadata,
tags=tags,
)

return trace_info
return self._client._start_tracked_trace(
experiment_id=experiment_id,
# TODO: This timestamp is not accurate because it is not adjusted to exclude the
# latency of the backend API call. We do this adjustment for span start time
# above, but can't do it for trace start time until the backend API supports
# updating the trace start time.
timestamp_ms=span.start_time // 1_000_000, # nanosecond to millisecond
request_metadata=metadata,
tags=tags,
)

def on_end(self, span: OTelReadableSpan) -> None:
"""
Expand Down
24 changes: 22 additions & 2 deletions mlflow/tracking/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
)
from mlflow.entities.model_registry import ModelVersion, RegisteredModel
from mlflow.entities.model_registry.model_version_stages import ALL_STAGES
from mlflow.entities.span import LiveSpan, NoOpSpan
from mlflow.entities.span import NO_OP_SPAN_REQUEST_ID, LiveSpan, NoOpSpan
from mlflow.entities.trace_status import TraceStatus
from mlflow.environment_variables import MLFLOW_ENABLE_ASYNC_LOGGING
from mlflow.exceptions import MlflowException
Expand Down Expand Up @@ -635,6 +635,12 @@ def end_trace(
:py:class:`SpanStatusCode <mlflow.entities.SpanStatusCode>`
e.g. ``"OK"``, ``"ERROR"``. The default status is OK.
"""
# NB: If the specified request ID is of no-op span, this means something went wrong in
# the span start logic. We should simply ignore it as the upstream should already
# have logged the error.
if request_id == NO_OP_SPAN_REQUEST_ID:
return

trace_manager = InMemoryTraceManager.get_instance()
root_span_id = trace_manager.get_root_span_id(request_id)

Expand Down Expand Up @@ -792,6 +798,10 @@ def start_span(
client.end_trace(request_id)
"""
# If parent span is no-op span, the child should also be no-op too
if request_id == NO_OP_SPAN_REQUEST_ID:
return NoOpSpan()

if not parent_id:
raise MlflowException(
"start_span() must be called with an explicit parent_id."
Expand Down Expand Up @@ -855,6 +865,9 @@ def end_span(
:py:class:`SpanStatusCode <mlflow.entities.SpanStatusCode>`
e.g. ``"OK"``, ``"ERROR"``. The default status is OK.
"""
if request_id == NO_OP_SPAN_REQUEST_ID:
return

trace_manager = InMemoryTraceManager.get_instance()
span = trace_manager.get_span_from_id(request_id, span_id)

Expand All @@ -870,7 +883,14 @@ def end_span(
span.set_outputs(outputs)
span.set_status(status)

span.end()
try:
span.end()
except Exception as e:
_logger.warning(
f"Failed to end span {span_id}: {e}. "
"For full traceback, set logging level to debug.",
exc_info=_logger.isEnabledFor(logging.DEBUG),
)

def _start_tracked_trace(
self,
Expand Down
38 changes: 1 addition & 37 deletions tests/tracing/processor/test_mlflow_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
)
from mlflow.tracing.processor.mlflow import MlflowSpanProcessor
from mlflow.tracing.trace_manager import InMemoryTraceManager
from mlflow.tracing.utils import encode_trace_id
from mlflow.utils.os import is_windows

from tests.tracing.helper import create_mock_otel_span, create_test_trace_info
Expand Down Expand Up @@ -74,7 +73,7 @@ def test_on_start_adjust_span_timestamp_to_exclude_backend_latency(clear_singlet
trace_info = create_test_trace_info(_REQUEST_ID, 0)
mock_client = mock.MagicMock()

def _mock_start_tracked_trace(*args, **kwargs):
    # Signature-agnostic stand-in for MlflowClient._start_tracked_trace:
    # accepts any positional/keyword arguments the processor passes.
    time.sleep(0.5)  # Simulate backend latency
    # trace_info is built in the enclosing test body from
    # create_test_trace_info(_REQUEST_ID, 0).
    return trace_info

Expand Down Expand Up @@ -125,41 +124,6 @@ def test_on_start_with_experiment_id(clear_singleton, monkeypatch):
assert _REQUEST_ID in InMemoryTraceManager.get_instance()._traces


def test_on_start_fallback_to_client_side_request_id(clear_singleton, monkeypatch):
    """When the backend `_start_tracked_trace` call raises, the processor should
    fall back to generating the request ID client-side from the OTel trace ID
    and still register an IN_PROGRESS trace in the in-memory manager.
    """
    # Disable testing mode so the fallback path runs instead of surfacing
    # the error — presumably MLFLOW_TESTING re-raises; confirm in processor.
    monkeypatch.setenv("MLFLOW_TESTING", "false")
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    span = create_mock_otel_span(
        trace_id=_TRACE_ID, span_id=1, parent_id=None, start_time=5_000_000
    )

    # Simulate the tracking server being unavailable.
    mock_client = mock.MagicMock()
    mock_client._start_tracked_trace.side_effect = Exception("error")
    processor = MlflowSpanProcessor(span_exporter=mock.MagicMock(), client=mock_client)

    processor.on_start(span)

    # The backend call is still attempted exactly once, with the span start
    # time converted from nanoseconds to milliseconds (5_000_000 ns -> 5 ms).
    mock_client._start_tracked_trace.assert_called_once_with(
        experiment_id="0",
        timestamp_ms=5,
        request_metadata={},
        tags={
            "mlflow.user": "bob",
            "mlflow.source.name": "test",
            "mlflow.source.type": "LOCAL",
            TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION),
        },
    )
    # When the backend returns an error, the request_id is generated at client side from trace_id
    expected_request_id = encode_trace_id(_TRACE_ID)
    assert span.attributes.get(SpanAttributeKey.REQUEST_ID) == json.dumps(expected_request_id)
    with InMemoryTraceManager.get_instance().get_trace(expected_request_id) as trace:
        assert trace.info.experiment_id == "0"
        assert trace.info.timestamp_ms == 5
        assert trace.info.execution_time_ms is None
        assert trace.info.status == TraceStatus.IN_PROGRESS


def test_on_start_during_model_evaluation(clear_singleton):
# Root span should create a new trace on start
span = create_mock_otel_span(trace_id=_TRACE_ID, span_id=1)
Expand Down
20 changes: 19 additions & 1 deletion tests/tracing/test_fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
TraceMetadataKey,
TraceTagKey,
)
from mlflow.tracing.fluent import TRACE_BUFFER
from mlflow.tracing.provider import _get_tracer

from tests.tracing.helper import create_test_trace_info, create_trace, get_first_trace, get_traces

Expand Down Expand Up @@ -349,7 +351,7 @@ def some_operation_raise_error(self, x, y):
assert len(trace.data.spans) == 2


def test_trace_ignore_exception_from_tracing_logic(clear_singleton):
def test_trace_ignore_exception_from_tracing_logic(clear_singleton, monkeypatch):
# This test is to make sure that the main prediction logic is not affected
# by the exception raised by the tracing logic.
class TestModel:
Expand All @@ -365,6 +367,7 @@ def predict(self, x, y):

assert output == 7
assert get_traces() == []
TRACE_BUFFER.clear()

# Exception during inspecting inputs: trace is logged without inputs field
with mock.patch("mlflow.tracing.utils.inspect.signature", side_effect=ValueError("Some error")):
Expand All @@ -374,6 +377,21 @@ def predict(self, x, y):
trace = get_first_trace()
assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == "{}"
assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == "7"
TRACE_BUFFER.clear()

# Exception during ending span: trace is not logged
# Mock the span processor's on_end handler to raise an exception
tracer = _get_tracer(__name__)

def _always_fail(*args, **kwargs):
raise ValueError("Some error")

monkeypatch.setattr(tracer.span_processor, "on_end", _always_fail)

output = model.predict(2, 5)
assert output == 7
assert get_traces() == []
TRACE_BUFFER.clear()


def test_start_span_context_manager(clear_singleton):
Expand Down
39 changes: 39 additions & 0 deletions tests/tracking/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from mlflow.store.tracking import SEARCH_MAX_RESULTS_DEFAULT
from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore as SqlAlchemyTrackingStore
from mlflow.tracing.constant import TRACE_SCHEMA_VERSION, TRACE_SCHEMA_VERSION_KEY, TraceMetadataKey
from mlflow.tracing.fluent import TRACE_BUFFER
from mlflow.tracing.provider import _get_tracer
from mlflow.tracing.trace_manager import InMemoryTraceManager
from mlflow.tracking import set_registry_uri
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
Expand Down Expand Up @@ -673,6 +675,43 @@ def test_start_span_raise_error_when_parent_id_is_not_provided():
mlflow.tracking.MlflowClient().start_span("span_name", request_id="test", parent_id=None)


def test_ignore_exception_from_tracing_logic(clear_singleton, monkeypatch):
    """Exceptions raised inside the span processor must not propagate to the
    model's prediction logic; the prediction result is returned and no trace
    is buffered.
    """
    exp_id = mlflow.set_experiment("test_experiment_1").experiment_id
    client = MlflowClient()
    TRACE_BUFFER.clear()

    class TestModel:
        def predict(self, x):
            # Full client-API trace lifecycle: root span, child span,
            # end child, end trace.
            root_span = client.start_trace(experiment_id=exp_id, name="predict")
            request_id = root_span.request_id
            child_span = client.start_span(
                name="child", request_id=request_id, parent_id=root_span.span_id
            )
            client.end_span(request_id, child_span.span_id)
            client.end_trace(request_id)
            return x

    model = TestModel()

    # Mock the span processor's on_start/on_end handlers to raise an exception.
    processor = _get_tracer(__name__).span_processor

    def _always_fail(*args, **kwargs):
        raise ValueError("Some error")

    # Exception while starting the trace should be caught, not raised.
    monkeypatch.setattr(processor, "on_start", _always_fail)
    response = model.predict(1)
    assert response == 1
    assert len(TRACE_BUFFER) == 0

    # Exception while ending the trace should be caught, not raised.
    # NOTE(review): on_start is still patched here, so start_trace yields a
    # no-op span and the end path may short-circuit before on_end ever fires —
    # confirm this section actually exercises the on_end guard.
    monkeypatch.setattr(processor, "on_end", _always_fail)
    response = model.predict(1)
    assert response == 1
    assert len(TRACE_BUFFER) == 0


def test_set_and_delete_trace_tag_on_active_trace(clear_singleton, monkeypatch):
monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")
Expand Down

0 comments on commit 24872d1

Please sign in to comment.