diff --git a/mlflow/langchain/langchain_tracer.py b/mlflow/langchain/langchain_tracer.py
index fb650640dbc61..1ce3b000fc373 100644
--- a/mlflow/langchain/langchain_tracer.py
+++ b/mlflow/langchain/langchain_tracer.py
@@ -64,6 +64,7 @@ def __init__(
         self._run_span_mapping: Dict[str, LiveSpan] = {}
         self._prediction_context = prediction_context
         self._request_id = None
+        self._root_run_id = None

     def _get_span_by_run_id(self, run_id: UUID) -> Optional[LiveSpan]:
         if span := self._run_span_mapping.get(str(run_id)):
@@ -106,6 +107,7 @@ def _start_span(
                 tags=dependencies_schema,
             )
             self._request_id = span.request_id
+            self._root_run_id = run_id

         self._run_span_mapping[str(run_id)] = span
         return span
@@ -128,12 +130,22 @@ def _get_parent_span(self, parent_run_id) -> Optional[LiveSpan]:

     def _end_span(
         self,
+        run_id: UUID,
         span: LiveSpan,
         outputs=None,
         attributes=None,
         status=SpanStatus(SpanStatusCode.OK),
     ):
         """Close MLflow Span (or Trace if it is root component)"""
+        root_run_active = str(self._root_run_id) in self._run_span_mapping
+        self._run_span_mapping.pop(str(run_id), None)
+        if not root_run_active:
+            # If the root run is not found in the mapping, it means that the root span is already
+            # closed. In this case, the trace is likely no longer active, so we do not attempt
+            # to write the span to the trace. For example, this occurs during streaming inference
+            # if the generator returned by stream() is not consumed completely
+            return
+
         with set_prediction_context(self._prediction_context):
             self._mlflow_client.end_span(
                 request_id=span.request_id,
@@ -252,7 +264,7 @@ def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any):
         """End the span for an LLM run."""
         llm_span = self._get_span_by_run_id(run_id)
         outputs = response.dict()
-        self._end_span(llm_span, outputs=outputs)
+        self._end_span(run_id, llm_span, outputs=outputs)

     def on_llm_error(
         self,
@@ -264,7 +276,7 @@ def on_llm_error(
         """Handle an error for an LLM run."""
         llm_span = self._get_span_by_run_id(run_id)
         llm_span.add_event(SpanEvent.from_exception(error))
-        self._end_span(llm_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
+        self._end_span(run_id, llm_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))

     def on_chain_start(
         self,
@@ -304,7 +316,7 @@ def on_chain_end(
         chain_span = self._get_span_by_run_id(run_id)
         if inputs:
             chain_span.set_inputs(inputs)
-        self._end_span(chain_span, outputs=outputs)
+        self._end_span(run_id, chain_span, outputs=outputs)

     def on_chain_error(
         self,
@@ -319,7 +331,7 @@ def on_chain_error(
         if inputs:
             chain_span.set_inputs(inputs)
         chain_span.add_event(SpanEvent.from_exception(error))
-        self._end_span(chain_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
+        self._end_span(run_id, chain_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))

     def on_tool_start(
         self,
@@ -349,7 +361,7 @@ def on_tool_start(
     def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any):
         """Run when tool ends running."""
         tool_span = self._get_span_by_run_id(run_id)
-        self._end_span(tool_span, outputs=str(output))
+        self._end_span(run_id, tool_span, outputs=str(output))

     def on_tool_error(
         self,
@@ -361,7 +373,7 @@ def on_tool_error(
         """Run when tool errors."""
         tool_span = self._get_span_by_run_id(run_id)
         tool_span.add_event(SpanEvent.from_exception(error))
-        self._end_span(tool_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
+        self._end_span(run_id, tool_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))

     def on_retriever_start(
         self,
@@ -390,7 +402,7 @@ def on_retriever_start(
     def on_retriever_end(self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any):
         """Run when Retriever ends running."""
         retriever_span = self._get_span_by_run_id(run_id)
-        self._end_span(retriever_span, outputs=documents)
+        self._end_span(run_id, retriever_span, outputs=documents)

     def on_retriever_error(
         self,
@@ -402,7 +414,7 @@ def on_retriever_error(
         """Run when Retriever errors."""
         retriever_span = self._get_span_by_run_id(run_id)
         retriever_span.add_event(SpanEvent.from_exception(error))
-        self._end_span(retriever_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
+        self._end_span(run_id, retriever_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))

     def on_agent_action(
         self,
diff --git a/tests/langchain/test_langchain_tracer.py b/tests/langchain/test_langchain_tracer.py
index 06552c06a93a5..edb2e377d5237 100644
--- a/tests/langchain/test_langchain_tracer.py
+++ b/tests/langchain/test_langchain_tracer.py
@@ -1,5 +1,6 @@
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, List, Optional
 from unittest import mock

 import pytest
@@ -21,6 +22,7 @@
 from mlflow.entities import Trace
 from mlflow.entities.span_event import SpanEvent
 from mlflow.entities.span_status import SpanStatus, SpanStatusCode
+from mlflow.exceptions import MlflowException
 from mlflow.langchain import _LangChainModelWrapper
 from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
 from mlflow.pyfunc.context import Context
@@ -541,3 +543,48 @@ def worker_function(worker_id):
     traces = get_traces()
     assert len(traces) == 10
     assert all(len(trace.data.spans) == 1 for trace in traces)
+
+
+def test_tracer_does_not_add_spans_to_trace_after_root_run_has_finished(clear_singleton):
+    from langchain.callbacks.manager import CallbackManagerForLLMRun
+    from langchain.chat_models.base import SimpleChatModel
+    from langchain.schema.messages import BaseMessage
+
+    class FakeChatModel(SimpleChatModel):
+        """Fake Chat Model wrapper for testing purposes."""
+
+        def _call(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> str:
+            return TEST_CONTENT
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake chat model"
+
+    run_id_for_on_chain_end = None
+
+    class ExceptionCatchingTracer(MlflowLangchainTracer):
+        def on_chain_end(self, outputs, *, run_id, inputs=None, **kwargs):
+            nonlocal run_id_for_on_chain_end
+            run_id_for_on_chain_end = run_id
+            super().on_chain_end(outputs, run_id=run_id, inputs=inputs, **kwargs)
+
+    prompt = SystemMessagePromptTemplate.from_template("You are a nice assistant.") + "{question}"
+    chain = prompt | FakeChatModel() | StrOutputParser()
+
+    tracer = ExceptionCatchingTracer()
+
+    chain.invoke(
+        "What is MLflow?",
+        config={"callbacks": [tracer]},
+    )
+
+    with pytest.raises(MlflowException, match="Span for run_id .* not found."):
+        # After the chain is invoked, verify that the tracer no longer holds references to spans,
+        # ensuring that the tracer does not add spans to the trace after the root run has finished
+        tracer.on_chain_end({"output": "test output"}, run_id=run_id_for_on_chain_end, inputs=None)
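
Note (not part of the patch): a minimal sketch of the streaming scenario the new `_end_span` guard is meant to cover. It assumes the `FakeChatModel` stand-in defined in the test above and the `SystemMessagePromptTemplate` / `StrOutputParser` imports already used in `tests/langchain/test_langchain_tracer.py`; the partial consumption of the stream is illustrative only, not a verified reproduction.

    from mlflow.langchain.langchain_tracer import MlflowLangchainTracer

    # Chain mirrors the one built in the test; FakeChatModel is the test's fake model.
    prompt = SystemMessagePromptTemplate.from_template("You are a nice assistant.") + "{question}"
    chain = prompt | FakeChatModel() | StrOutputParser()

    tracer = MlflowLangchainTracer()
    stream = chain.stream("What is MLflow?", config={"callbacks": [tracer]})
    next(stream)  # consume only the first chunk, then abandon the generator

    # With this change, a late on_*_end callback for a run whose root span has already been
    # closed is dropped: _end_span pops the run_id from _run_span_mapping and returns early
    # once the root run_id is no longer present, instead of writing to an inactive trace.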