Skip to content

Commit

Permalink
Add get_last_active_trace() API (#12207)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
  • Loading branch information
B-Step62 committed Jun 3, 2024
1 parent 667d7a7 commit 72df4a2
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 28 deletions.
2 changes: 2 additions & 0 deletions mlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
from mlflow.projects import run
from mlflow.tracing.fluent import (
get_current_active_span,
get_last_active_trace,
get_trace,
search_traces,
start_span,
Expand Down Expand Up @@ -182,6 +183,7 @@
"get_artifact_uri",
"get_experiment",
"get_experiment_by_name",
"get_last_active_trace",
"get_parent_run",
"get_registry_uri",
"get_run",
Expand Down
62 changes: 62 additions & 0 deletions mlflow/tracing/fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MLFLOW_TRACE_BUFFER_TTL_SECONDS,
)
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import BAD_REQUEST
from mlflow.store.tracking import SEARCH_TRACES_DEFAULT_MAX_RESULTS
from mlflow.tracing import provider
from mlflow.tracing.constant import SpanAttributeKey
Expand All @@ -34,6 +35,7 @@
from mlflow.tracking.fluent import _get_experiment_id
from mlflow.utils import get_results_from_paginated_fn
from mlflow.utils.annotations import experimental
from mlflow.utils.databricks_utils import is_in_databricks_model_serving_environment

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -421,3 +423,63 @@ def f():
trace_manager = InMemoryTraceManager.get_instance()
request_id = json.loads(otel_span.attributes.get(SpanAttributeKey.REQUEST_ID))
return trace_manager.get_span_from_id(request_id, encode_span_id(otel_span.context.span_id))


@experimental
def get_last_active_trace() -> Optional[Trace]:
    """
    Get the last active trace in the same process, if one exists.

    .. warning::

        This function DOES NOT work in the model deployed in Databricks model serving.

    .. note::

        The last active trace is only stored in-memory for the time defined by the TTL
        (Time To Live) configuration. By default, the TTL is 1 hour and can be configured
        using the environment variable ``MLFLOW_TRACE_BUFFER_TTL_SECONDS``.

    .. note::

        The returned trace object is separate from the trace that is logged in the
        tracking store. Any changes made to the returned object will not be reflected
        in the logged trace. To modify an already ended trace (while most of the data is
        immutable after the trace is ended, you can still edit some fields such as
        `tags`), please use the respective MlflowClient APIs with the request ID of the
        trace, as shown in the example below.

    .. code-block:: python
        :test:

        import mlflow


        @mlflow.trace
        def f():
            pass


        f()

        trace = mlflow.get_last_active_trace()

        # Use MlflowClient APIs to mutate the ended trace
        mlflow.MlflowClient().set_trace_tag(trace.info.request_id, "key", "value")

    Returns:
        The last active trace if it exists, otherwise None.

    Raises:
        MlflowException: With error code ``BAD_REQUEST`` when called inside the
            Databricks model serving environment.
    """
    if is_in_databricks_model_serving_environment():
        raise MlflowException(
            "The function `mlflow.get_last_active_trace` is not supported in "
            "Databricks model serving.",
            error_code=BAD_REQUEST,
        )

    # Snapshot the keys exactly once and guard on that snapshot. Checking
    # ``len(TRACE_BUFFER)`` first and listing the keys afterwards is two separate
    # reads of the buffer; if the buffer shrinks in between, indexing ``[-1]``
    # on the now-empty list raises IndexError instead of returning None.
    # NOTE(review): assumes TRACE_BUFFER evicts entries on TTL expiry, as the
    # docstring above describes - confirm against its definition.
    request_ids = list(TRACE_BUFFER.keys())
    if not request_ids:
        return None

    # ``get`` (rather than ``[]``) keeps the result None if the entry vanished
    # after the snapshot was taken.
    return TRACE_BUFFER.get(request_ids[-1])
6 changes: 3 additions & 3 deletions tests/entities/test_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mlflow.utils.mlflow_tags import MLFLOW_ARTIFACT_LOCATION

from tests.tracing.conftest import clear_singleton # noqa: F401
from tests.tracing.helper import create_test_trace_info, get_first_trace
from tests.tracing.helper import create_test_trace_info


def _test_model(datetime=datetime.now()):
Expand Down Expand Up @@ -49,7 +49,7 @@ def test_json_deserialization(clear_singleton, monkeypatch):
model = _test_model(datetime_now)
model.predict(2, 5)

trace = get_first_trace()
trace = mlflow.get_last_active_trace()
trace_json = trace.to_json()

trace_json_as_dict = json.loads(trace_json)
Expand Down Expand Up @@ -194,7 +194,7 @@ def test_trace_to_from_dict_and_json(clear_singleton):
model = _test_model()
model.predict(2, 5)

trace = get_first_trace()
trace = mlflow.get_last_active_trace()
trace_dict = trace.to_dict()
trace_from_dict = Trace.from_dict(trace_dict)
trace_json = trace.to_json()
Expand Down
3 changes: 1 addition & 2 deletions tests/entities/test_trace_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from mlflow.entities.span_event import SpanEvent

from tests.tracing.conftest import clear_singleton # noqa: F401
from tests.tracing.helper import get_first_trace


def test_json_deserialization(clear_singleton):
Expand All @@ -32,7 +31,7 @@ def always_fail(self):
with pytest.raises(Exception, match="Error!"):
model.predict(2, 5)

trace = get_first_trace()
trace = mlflow.get_last_active_trace()
trace_data = trace.data

# Compare events separately as it includes exception stacktrace which is hard to hardcode
Expand Down
4 changes: 2 additions & 2 deletions tests/langchain/test_langchain_autolog.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

# TODO: This test helper is used outside the tracing module; we should move it to a common utils module.
from tests.tracing.conftest import clear_singleton as clear_trace_singleton # noqa: F401
from tests.tracing.helper import get_first_trace, get_traces
from tests.tracing.helper import get_traces

MODEL_DIR = "model"
TEST_CONTENT = "test"
Expand Down Expand Up @@ -975,5 +975,5 @@ def test_set_retriever_schema_work_for_langchain_model(clear_trace_singleton):
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
pyfunc_model.predict("MLflow")

trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert DependenciesSchemasType.RETRIEVERS.value in trace.info.tags
12 changes: 6 additions & 6 deletions tests/langchain/test_langchain_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)

from tests.tracing.conftest import clear_singleton # noqa: F401
from tests.tracing.helper import get_first_trace, get_traces
from tests.tracing.helper import get_traces

TEST_CONTENT = "test"

Expand Down Expand Up @@ -127,7 +127,7 @@ def test_llm_success(clear_singleton):
callback.on_llm_new_token("test", run_id=run_id)

callback.on_llm_end(LLMResult(generations=[[{"text": "generated text"}]]), run_id=run_id)
trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert len(trace.data.spans) == 1
llm_span = trace.data.spans[0]

Expand Down Expand Up @@ -159,7 +159,7 @@ def test_llm_error(clear_singleton):
mock_error = Exception("mock exception")
callback.on_llm_error(error=mock_error, run_id=run_id)

trace = get_first_trace()
trace = mlflow.get_last_active_trace()
error_event = SpanEvent.from_exception(mock_error)
assert len(trace.data.spans) == 1
llm_span = trace.data.spans[0]
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_retriever_success(clear_singleton):
),
]
callback.on_retriever_end(documents, run_id=run_id)
trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert len(trace.data.spans) == 1
retriever_span = trace.data.spans[0]

Expand All @@ -238,7 +238,7 @@ def test_retriever_error(clear_singleton):
)
mock_error = Exception("mock exception")
callback.on_retriever_error(error=mock_error, run_id=run_id)
trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert len(trace.data.spans) == 1
retriever_span = trace.data.spans[0]
assert retriever_span.attributes[SpanAttributeKey.INPUTS] == "test query"
Expand Down Expand Up @@ -322,7 +322,7 @@ def test_multiple_components(clear_singleton):
outputs={"output": "test output"},
run_id=chain_run_id,
)
trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert len(trace.data.spans) == 5
chain_span = trace.data.spans[0]
assert chain_span.start_time_ns is not None
Expand Down
5 changes: 0 additions & 5 deletions tests/tracing/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,6 @@ def get_traces() -> List[Trace]:
return list(TRACE_BUFFER.values())


def get_first_trace() -> Optional[Trace]:
    """Return the first trace reported by ``get_traces()``, or None when there are none."""
    buffered = get_traces()
    if not buffered:
        return None
    return buffered[0]


def get_tracer_tracking_uri() -> Optional[str]:
"""Get current tracking URI configured as the trace export destination."""
from opentelemetry import trace
Expand Down
33 changes: 27 additions & 6 deletions tests/tracing/test_fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from mlflow.tracing.fluent import TRACE_BUFFER
from mlflow.tracing.provider import _get_tracer

from tests.tracing.helper import create_test_trace_info, create_trace, get_first_trace, get_traces
from tests.tracing.helper import create_test_trace_info, create_trace, get_traces


class DefaultTestModel:
Expand Down Expand Up @@ -69,7 +69,7 @@ def test_trace(clear_singleton, with_active_run):
else:
model.predict(2, 5)

trace = get_first_trace()
trace = mlflow.get_last_active_trace()
trace_info = trace.info
assert trace_info.request_id is not None
assert trace_info.experiment_id == "0" # default experiment
Expand Down Expand Up @@ -340,7 +340,7 @@ def some_operation_raise_error(self, x, y):
model.predict(2, 5)

# Trace should be logged even if the function fails, with status code ERROR
trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert trace.info.request_id is not None
assert trace.info.status == TraceStatus.ERROR
assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 2, "y": 5}'
Expand Down Expand Up @@ -374,7 +374,7 @@ def predict(self, x, y):
output = model.predict(2, 5)

assert output == 7
trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == "{}"
assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == "7"
TRACE_BUFFER.clear()
Expand Down Expand Up @@ -424,7 +424,7 @@ def square(self, t):
model = TestModel()
model.predict(1, 2)

trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert trace.info.request_id is not None
assert trace.info.experiment_id == "0" # default experiment
assert trace.info.execution_time_ms >= 0.1 * 1e3 # at least 0.1 sec
Expand Down Expand Up @@ -506,7 +506,7 @@ def predict(self, x, y):
model = TestModel()
model.predict(1, 2)

trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert trace.info.request_id is not None
assert trace.info.experiment_id == "0" # default experiment
assert trace.info.execution_time_ms >= 0.1 * 1e3 # at least 0.1 sec
Expand Down Expand Up @@ -906,3 +906,24 @@ def search_traces(self, experiment_ids, *args, **kwargs):
mlflow.search_traces(
extract_fields=["span.llm.inputs", "span.invalidname.outputs", "span.llm.inputs.x"]
)


def test_get_last_active_trace(clear_singleton):
    """Check that ``mlflow.get_last_active_trace`` returns the most recently logged trace."""
    # Before any traced call there is nothing to return.
    assert mlflow.get_last_active_trace() is None

    @mlflow.trace()
    def predict(x, y):
        return x + y

    # Produce three traces; only the final one should be reported as "last active".
    for x, y in [(1, 2), (2, 5), (3, 6)]:
        predict(x, y)

    last_trace = mlflow.get_last_active_trace()
    assert last_trace.info.request_id is not None
    assert last_trace.data.request == '{"x": 3, "y": 6}'

    # Mutation of the copy should not affect the original trace logged in the backend
    last_trace.info.status = TraceStatus.ERROR
    backend_trace = mlflow.MlflowClient().get_trace(last_trace.info.request_id)
    assert backend_trace.info.status == TraceStatus.OK
7 changes: 3 additions & 4 deletions tests/tracking/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
from tests.tracing.conftest import mock_store as mock_store_for_tracing # noqa: F401
from tests.tracing.helper import (
create_test_trace_info,
get_first_trace,
get_traces,
)

Expand Down Expand Up @@ -429,7 +428,7 @@ def square(self, t, request_id, parent_id):
else:
model.predict(1, 2)

request_id = get_first_trace().info.request_id
request_id = mlflow.get_last_active_trace().info.request_id

# Validate that trace is logged to the backend
trace = client.get_trace(request_id)
Expand Down Expand Up @@ -794,7 +793,7 @@ def test_set_and_delete_trace_tag_on_active_trace(clear_singleton, monkeypatch):
client.set_trace_tag(request_id, "foo", "bar")
client.end_trace(request_id)

trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert trace.info.tags["foo"] == "bar"


Expand All @@ -813,7 +812,7 @@ def test_delete_trace_tag_on_active_trace(clear_singleton, monkeypatch):
client.delete_trace_tag(request_id, "foo")
client.end_trace(request_id)

trace = get_first_trace()
trace = mlflow.get_last_active_trace()
assert "baz" in trace.info.tags
assert "foo" not in trace.info.tags

Expand Down

0 comments on commit 72df4a2

Please sign in to comment.