LangChain tracing: only end spans if the trace is still active (#12049)
Signed-off-by: dbczumar <corey.zumar@databricks.com>
dbczumar committed May 19, 2024
1 parent b53cd2e commit f7c420f
Showing 2 changed files with 67 additions and 8 deletions.
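In outline, the patch threads each callback's run_id through _end_span and records the root run's id when the trace starts, so a span is written to the trace only while the root run is still being tracked. Condensed from the diff below:

    root_run_active = str(self._root_run_id) in self._run_span_mapping
    self._run_span_mapping.pop(str(run_id), None)
    if not root_run_active:
        return  # the root span is already closed, so the trace is no longer active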
28 changes: 20 additions & 8 deletions mlflow/langchain/langchain_tracer.py
@@ -64,6 +64,7 @@ def __init__(
         self._run_span_mapping: Dict[str, LiveSpan] = {}
         self._prediction_context = prediction_context
         self._request_id = None
+        self._root_run_id = None
 
     def _get_span_by_run_id(self, run_id: UUID) -> Optional[LiveSpan]:
         if span := self._run_span_mapping.get(str(run_id)):
@@ -106,6 +107,7 @@ def _start_span(
             tags=dependencies_schema,
         )
         self._request_id = span.request_id
+        self._root_run_id = run_id
 
         self._run_span_mapping[str(run_id)] = span
         return span
@@ -128,12 +130,22 @@ def _get_parent_span(self, parent_run_id) -> Optional[LiveSpan]:
 
     def _end_span(
         self,
+        run_id: UUID,
         span: LiveSpan,
         outputs=None,
         attributes=None,
         status=SpanStatus(SpanStatusCode.OK),
     ):
         """Close MLflow Span (or Trace if it is root component)"""
+        root_run_active = str(self._root_run_id) in self._run_span_mapping
+        self._run_span_mapping.pop(str(run_id), None)
+        if not root_run_active:
+            # If the root run is not found in the mapping, it means that the root span is already
+            # closed. In this case, the trace is likely no longer active, so we do not attempt
+            # to write the span to the trace. For example, this occurs during streaming inference
+            # if the generator returned by stream() is not consumed completely
+            return
+
         with set_prediction_context(self._prediction_context):
             self._mlflow_client.end_span(
                 request_id=span.request_id,
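The comment above cites streaming inference as the motivating case. A minimal sketch of that usage pattern, with a hypothetical generator-backed chain standing in for a real streaming model (the langchain_core.runnables import path is an assumption and varies across LangChain versions):

    from langchain_core.runnables import RunnableGenerator

    from mlflow.langchain.langchain_tracer import MlflowLangchainTracer

    def fake_llm(inputs):
        # Hypothetical multi-chunk streaming step in place of a chat model.
        yield from ["MLflow ", "is ", "an ", "open ", "source ", "platform."]

    chain = RunnableGenerator(fake_llm)
    tracer = MlflowLangchainTracer()

    stream = chain.stream("What is MLflow?", config={"callbacks": [tracer]})
    next(stream)  # consume only the first chunk
    del stream    # the generator is closed without being exhausted

When the stream is abandoned like this, the root run can be finalized while child runs are still pending; before this patch, their _end_span calls attempted to write to the already-closed trace. With the guard above, they return early instead.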
@@ -252,7 +264,7 @@ def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any):
         """End the span for an LLM run."""
         llm_span = self._get_span_by_run_id(run_id)
         outputs = response.dict()
-        self._end_span(llm_span, outputs=outputs)
+        self._end_span(run_id, llm_span, outputs=outputs)
 
     def on_llm_error(
         self,
@@ -264,7 +276,7 @@ def on_llm_error(
         """Handle an error for an LLM run."""
         llm_span = self._get_span_by_run_id(run_id)
         llm_span.add_event(SpanEvent.from_exception(error))
-        self._end_span(llm_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
+        self._end_span(run_id, llm_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
 
     def on_chain_start(
         self,
@@ -304,7 +316,7 @@ def on_chain_end(
         chain_span = self._get_span_by_run_id(run_id)
         if inputs:
             chain_span.set_inputs(inputs)
-        self._end_span(chain_span, outputs=outputs)
+        self._end_span(run_id, chain_span, outputs=outputs)
 
     def on_chain_error(
         self,
@@ -319,7 +331,7 @@ def on_chain_error(
         if inputs:
             chain_span.set_inputs(inputs)
         chain_span.add_event(SpanEvent.from_exception(error))
-        self._end_span(chain_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
+        self._end_span(run_id, chain_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
 
     def on_tool_start(
         self,
@@ -349,7 +361,7 @@ def on_tool_start(
     def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any):
         """Run when tool ends running."""
         tool_span = self._get_span_by_run_id(run_id)
-        self._end_span(tool_span, outputs=str(output))
+        self._end_span(run_id, tool_span, outputs=str(output))
 
     def on_tool_error(
         self,
@@ -361,7 +373,7 @@ def on_tool_error(
         """Run when tool errors."""
         tool_span = self._get_span_by_run_id(run_id)
         tool_span.add_event(SpanEvent.from_exception(error))
-        self._end_span(tool_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
+        self._end_span(run_id, tool_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
 
     def on_retriever_start(
         self,
@@ -390,7 +402,7 @@ def on_retriever_start(
     def on_retriever_end(self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any):
         """Run when Retriever ends running."""
         retriever_span = self._get_span_by_run_id(run_id)
-        self._end_span(retriever_span, outputs=documents)
+        self._end_span(run_id, retriever_span, outputs=documents)
 
     def on_retriever_error(
         self,
@@ -402,7 +414,7 @@ def on_retriever_error(
         """Run when Retriever errors."""
         retriever_span = self._get_span_by_run_id(run_id)
         retriever_span.add_event(SpanEvent.from_exception(error))
-        self._end_span(retriever_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
+        self._end_span(run_id, retriever_span, status=SpanStatus(SpanStatusCode.ERROR, str(error)))
 
     def on_agent_action(
         self,
47 changes: 47 additions & 0 deletions tests/langchain/test_langchain_tracer.py
@@ -1,5 +1,6 @@
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, List, Optional
 from unittest import mock
 
 import pytest
@@ -21,6 +22,7 @@
 from mlflow.entities import Trace
 from mlflow.entities.span_event import SpanEvent
 from mlflow.entities.span_status import SpanStatus, SpanStatusCode
+from mlflow.exceptions import MlflowException
 from mlflow.langchain import _LangChainModelWrapper
 from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
 from mlflow.pyfunc.context import Context
@@ -541,3 +543,48 @@ def worker_function(worker_id):
     traces = get_traces()
     assert len(traces) == 10
     assert all(len(trace.data.spans) == 1 for trace in traces)
+
+
+def test_tracer_does_not_add_spans_to_trace_after_root_run_has_finished(clear_singleton):
+    from langchain.callbacks.manager import CallbackManagerForLLMRun
+    from langchain.chat_models.base import SimpleChatModel
+    from langchain.schema.messages import BaseMessage
+
+    class FakeChatModel(SimpleChatModel):
+        """Fake Chat Model wrapper for testing purposes."""
+
+        def _call(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> str:
+            return TEST_CONTENT
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake chat model"
+
+    run_id_for_on_chain_end = None
+
+    class ExceptionCatchingTracer(MlflowLangchainTracer):
+        def on_chain_end(self, outputs, *, run_id, inputs=None, **kwargs):
+            nonlocal run_id_for_on_chain_end
+            run_id_for_on_chain_end = run_id
+            super().on_chain_end(outputs, run_id=run_id, inputs=inputs, **kwargs)
+
+    prompt = SystemMessagePromptTemplate.from_template("You are a nice assistant.") + "{question}"
+    chain = prompt | FakeChatModel() | StrOutputParser()
+
+    tracer = ExceptionCatchingTracer()
+
+    chain.invoke(
+        "What is MLflow?",
+        config={"callbacks": [tracer]},
+    )
+
+    with pytest.raises(MlflowException, match="Span for run_id .* not found."):
+        # After the chain is invoked, verify that the tracer no longer holds references to spans,
+        # ensuring that the tracer does not add spans to the trace after the root run has finished
+        tracer.on_chain_end({"output": "test output"}, run_id=run_id_for_on_chain_end, inputs=None)
