Langchain tracer improvement #11992

Merged (13 commits) on May 16, 2024
10 changes: 7 additions & 3 deletions mlflow/entities/span.py
@@ -166,10 +166,14 @@ def get_attribute(self, key: str) -> Optional[Any]:
"""
return self._attributes.get(key)

def to_dict(self):
def to_dict(self, dump_events=False):
# NB: OpenTelemetry Span has to_json() method, but it will write many fields that
# we don't use e.g. links, kind, resource, trace_state, etc. So we manually
# cherry-pick the fields we need here.
if dump_events:
events = [event.json() for event in self.events]
else:
events = [asdict(event) for event in self.events]
return {
"name": self.name,
"context": {
@@ -179,10 +183,10 @@ def to_dict(self):
"parent_id": self.parent_id,
"start_time": self.start_time_ns,
"end_time": self.end_time_ns,
"status_code": self.status.status_code,
"status_code": self.status.status_code.value,
"status_message": self.status.description,
"attributes": dict(self._span.attributes),
"events": [asdict(event) for event in self.events],
"events": events,
}

@classmethod
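For context, a minimal sketch (not part of the diff) of how the two event formats differ; the exact SpanEvent constructor defaults are assumed here:

from dataclasses import asdict

from mlflow.entities import SpanEvent

# Hypothetical event, purely for illustration.
event = SpanEvent(name="exception", attributes={"error": "boom"})

# Default path: events stay as plain dicts with nested attribute values.
asdict(event)   # {"name": "exception", "timestamp": ..., "attributes": {"error": "boom"}}

# dump_events=True path (used by the inference table exporter below): attributes are
# JSON-encoded into a string via the new SpanEvent.json().
event.json()    # {"name": "exception", "timestamp": ..., "attributes": '{"error": "boom"}'}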
27 changes: 27 additions & 0 deletions mlflow/entities/span_event.py
@@ -1,7 +1,9 @@
import json
import sys
import time
import traceback
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict

from opentelemetry.util.types import AttributeValue
@@ -57,3 +59,28 @@ def _get_stacktrace(error: BaseException) -> str:
return (msg + "\n\n".join(tb)).strip()
except Exception:
return msg

def json(self):
return {
"name": self.name,
"timestamp": self.timestamp,
"attributes": json.dumps(self.attributes, cls=CustomEncoder)
if self.attributes
else None,
}


class CustomEncoder(json.JSONEncoder):
"""
Custom encoder to handle json serialization.
"""

def default(self, o):
try:
return super().default(o)
except TypeError:
# convert datetime to string format by default
if isinstance(o, datetime):
return o.isoformat()
# convert object direct to string to avoid error in serialization
return str(o)
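A rough standalone illustration of the encoder's fallback behaviour (the printed values are assumptions, not captured output):

import json
from datetime import datetime

from mlflow.entities.span_event import CustomEncoder

payload = {"ts": datetime(2024, 5, 16, 12, 0, 0), "obj": object()}

# datetime values are ISO-formatted; anything else unserializable falls back to str().
json.dumps(payload, cls=CustomEncoder)
# '{"ts": "2024-05-16T12:00:00", "obj": "<object object at 0x...>"}'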
7 changes: 7 additions & 0 deletions mlflow/environment_variables.py
@@ -604,3 +604,10 @@ def get(self):
MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS = _BooleanEnvironmentVariable(
"MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", False
)

#: Private configuration option.
#: Whether to use the MLflow LangChain tracer for RAG tracing. This should be set in
#: model serving for RAG models.
_USE_MLFLOW_LANGCHAIN_TRACER_FOR_RAG_TRACING = _BooleanEnvironmentVariable(
"USE_MLFLOW_LANGCHAIN_TRACER_FOR_RAG_TRACING", False
)
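Since the variable name carries no MLFLOW_ prefix, enabling the flag in a serving environment would look roughly like this (illustrative only):

import os

os.environ["USE_MLFLOW_LANGCHAIN_TRACER_FOR_RAG_TRACING"] = "true"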
166 changes: 149 additions & 17 deletions mlflow/langchain/langchain_tracer.py
@@ -1,5 +1,7 @@
import json
import logging
from contextvars import Context
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID

@@ -18,11 +20,21 @@
import mlflow
from mlflow import MlflowClient
from mlflow.entities import LiveSpan, SpanEvent, SpanStatus, SpanStatusCode, SpanType
from mlflow.environment_variables import _USE_MLFLOW_LANGCHAIN_TRACER_FOR_RAG_TRACING
from mlflow.exceptions import MlflowException
from mlflow.pyfunc.context import set_prediction_context
from mlflow.langchain.utils import (
DATABRICKS_VECTOR_SEARCH_DOC_URI,
DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY,
get_databricks_vector_search_key,
)
from mlflow.pyfunc.context import Context, set_prediction_context
from mlflow.tracing.export.inference_table import pop_trace
from mlflow.utils.autologging_utils import ExceptionSafeAbstractClass

_logger = logging.getLogger(__name__)
# Vector Search index column names
VS_INDEX_ID_COL = "chunk_id"
VS_INDEX_DOC_URL_COL = "doc_uri"


class MlflowLangchainTracer(BaseCallbackHandler, metaclass=ExceptionSafeAbstractClass):
@@ -61,6 +73,21 @@ def __init__(
self._parent_span = parent_span
self._run_span_mapping: Dict[str, LiveSpan] = {}
self._prediction_context = prediction_context
self._request_id = None

def _dump_trace(self) -> str:
Contributor:
Can we introduce a method which builds a Trace from the dumped str (or does this already exist)? It will be necessary to reconstruct the trace from payload logs of a served model.

Collaborator (author):
Let's discuss on the doc, and if we really need it I can file a follow-up PR :D

"""
This method is only used to get the trace data from the buffer in databricks
serving, it should return the trace in dictionary format and then dump to string.
"""

def _default_converter(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")

trace = pop_trace(self._request_id)
return json.dumps(trace, default=_default_converter)
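For reference, a minimal sketch (not part of the change) of how a caller could consume the dumped payload, assuming tracer is an MlflowLangchainTracer that has handled a request; rebuilding a full Trace object from the string is the follow-up discussed in the thread above:

import json

payload = tracer._dump_trace()    # JSON string popped from the serving buffer
trace_dict = json.loads(payload)  # plain dict; datetime values were ISO-formatted on dump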

def _get_span_by_run_id(self, run_id: UUID) -> Optional[LiveSpan]:
if span := self._run_span_mapping.get(str(run_id)):
@@ -77,24 +104,25 @@ def _start_span(
attributes: Optional[Dict[str, Any]] = None,
) -> LiveSpan:
"""Start MLflow Span (or Trace if it is root component)"""
parent = self._get_parent_span(parent_run_id)
if parent:
span = self._mlflow_client.start_span(
name=span_name,
request_id=parent.request_id,
parent_id=parent.span_id,
span_type=span_type,
inputs=inputs,
attributes=attributes,
)
else:
# When parent_run_id is None, this is root component so start trace
with set_prediction_context(self._prediction_context):
with set_prediction_context(self._prediction_context):
parent = self._get_parent_span(parent_run_id)
if parent:
span = self._mlflow_client.start_span(
name=span_name,
request_id=parent.request_id,
parent_id=parent.span_id,
span_type=span_type,
inputs=inputs,
attributes=attributes,
)
else:
# When parent_run_id is None, this is root component so start trace
span = self._mlflow_client.start_trace(
name=span_name, span_type=span_type, inputs=inputs, attributes=attributes
)
self._request_id = span.request_id

self._run_span_mapping[str(run_id)] = span
self._run_span_mapping[str(run_id)] = span
return span

def _get_parent_span(self, parent_run_id) -> Optional[LiveSpan]:
@@ -175,6 +203,8 @@ def on_llm_start(
"""Run when LLM (non-chat models) starts running."""
if metadata:
kwargs.update({"metadata": metadata})
if _USE_MLFLOW_LANGCHAIN_TRACER_FOR_RAG_TRACING.get():
prompts = convert_llm_inputs(prompts)
self._start_span(
span_name=name or self._assign_span_name(serialized, "llm"),
parent_run_id=parent_run_id,
@@ -238,7 +268,10 @@ def on_retry(
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any):
"""End the span for an LLM run."""
llm_span = self._get_span_by_run_id(run_id)
self._end_span(llm_span, outputs=response.dict())
outputs = response.dict()
if _USE_MLFLOW_LANGCHAIN_TRACER_FOR_RAG_TRACING.get():
outputs = convert_llm_outputs(outputs)
self._end_span(llm_span, outputs=outputs)

def on_llm_error(
self,
@@ -376,6 +409,8 @@ def on_retriever_start(
def on_retriever_end(self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any):
"""Run when Retriever ends running."""
retriever_span = self._get_span_by_run_id(run_id)
if _USE_MLFLOW_LANGCHAIN_TRACER_FOR_RAG_TRACING.get():
documents = convert_retriever_outputs(documents)
self._end_span(retriever_span, outputs=documents)

def on_retriever_error(
@@ -457,3 +492,100 @@ def flush_tracker(self):
self._reset()
except Exception as e:
_logger.debug(f"Failed to flush MLflow tracer due to error {e}.")


@dataclass
class Chunk:
chunk_id: Optional[str] = None
doc_uri: Optional[str] = None
content: Optional[str] = None


def convert_retriever_outputs(
documents: List[Document],
) -> Dict[str, List[Dict[str, str]]]:
"""
Convert Retriever outputs format from
[
Document(page_content="...", metadata={...}, type="Document"),
...
]
to
{
"chunks": [
{
"chunk_id": "...",
"doc_uri": "...",
"content": "...",
},
...
]
}
"""
chunk_id_col = (
get_databricks_vector_search_key(DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY) or VS_INDEX_ID_COL
)
doc_uri_col = (
get_databricks_vector_search_key(DATABRICKS_VECTOR_SEARCH_DOC_URI) or VS_INDEX_DOC_URL_COL
)
return {
"chunks": [
asdict(
Chunk(
chunk_id=doc.metadata.get(chunk_id_col),
doc_uri=doc.metadata.get(doc_uri_col),
content=doc.page_content,
)
)
for doc in documents
]
}
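A usage sketch for the converter above, assuming langchain_core's Document and no Databricks vector search schema registered (so the default chunk_id/doc_uri column names apply):

from langchain_core.documents import Document

docs = [
    Document(
        page_content="MLflow traces LangChain runs.",
        metadata={"chunk_id": "c-1", "doc_uri": "https://example.com/mlflow.html"},
    )
]

convert_retriever_outputs(docs)
# {"chunks": [{"chunk_id": "c-1",
#              "doc_uri": "https://example.com/mlflow.html",
#              "content": "MLflow traces LangChain runs."}]}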


# TODO: Multiple prompts are not supported for now;
# we should update this once RAG eval supports it.
def convert_llm_inputs(
prompts: List[str],
) -> Dict[str, str]:
"""
Convert LLM inputs format from
["prompt1"...]
to
{"prompt": "prompt1"}
"""
if len(prompts) == 0:
prompt = ""
elif len(prompts) == 1:
prompt = prompts[0]
else:
raise NotImplementedError(f"Multiple prompts not supported yet. Got: {prompts}")
return {"prompt": prompt}


# TODO: This assumes we are not in a batch situation with multiple generations per
# input; we should handle that case once RAG eval supports it.
def convert_llm_outputs(
outputs: Dict[str, Any],
) -> Dict[str, Any]:
"""
Convert LLM outputs format from
{
# generations is List[List[Generation]] because
# each input could have multiple candidate generations.
"generations": [[{"text": "...", "generation_info": {...}, "type": "Generation"}]],
"llm_output": {...},
"run": [{"run_id": "..."}],
}
to
{
# Convert to str by extracting first text field of Generation
"generated_text": "generated_text1",
}
"""
generated_text = ""
if "generations" in outputs:
generations: List[List[Dict[str, Any]]] = outputs["generations"]
if len(generations) > 0 and len(generations[0]) > 0:
first_generation: Dict[str, Any] = generations[0][0]
generated_text = first_generation.get("text", "")
return {"generated_text": generated_text}
50 changes: 49 additions & 1 deletion mlflow/langchain/utils.py
@@ -13,7 +13,7 @@
import warnings
from functools import lru_cache
from importlib.util import find_spec
from typing import Callable, NamedTuple
from typing import Callable, List, NamedTuple, Optional

import cloudpickle
import yaml
@@ -729,3 +729,51 @@ def register_pydantic_v1_serializer_cm():
yield
finally:
unregister_pydantic_serializer()


DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY = "__databricks_vector_search_primary_key__"
DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN = "__databricks_vector_search_text_column__"
DATABRICKS_VECTOR_SEARCH_DOC_URI = "__databricks_vector_search_doc_uri__"
DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS = "__databricks_vector_search_other_columns__"


def set_vector_search_schema(
primary_key: str,
text_column: str = "",
doc_uri: str = "",
other_columns: Optional[List[str]] = None,
):
"""
After defining your vector store in a Python file or notebook, call
set_vector_search_schema() so that we can correctly map the vector index
columns. These columns would be used during tracing and in the review UI.

Args:
primary_key: The primary key of the vector index.
text_column: The name of the text column to use for the embeddings.
doc_uri: The name of the column that contains the document URI.
other_columns: A list of other columns that are part of the vector index
that need to be retrieved during trace logging.
Note: Make sure the text column specified is in the index.

Example:

.. code-block:: python

from mlflow.langchain.utils import set_vector_search_schema

set_vector_search_schema(
primary_key="chunk_id",
text_column="chunk_text",
doc_uri="doc_uri",
other_columns=["title"],
)
"""
globals()[DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY] = primary_key
globals()[DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN] = text_column
globals()[DATABRICKS_VECTOR_SEARCH_DOC_URI] = doc_uri
globals()[DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS] = other_columns or []


def get_databricks_vector_search_key(key):
return globals().get(key)
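A short sketch of how the schema registration above feeds the retriever-output conversion in the tracer (column names are illustrative):

from mlflow.langchain.utils import (
    DATABRICKS_VECTOR_SEARCH_DOC_URI,
    DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY,
    get_databricks_vector_search_key,
    set_vector_search_schema,
)

set_vector_search_schema(primary_key="chunk_id", text_column="chunk_text", doc_uri="source_url")

get_databricks_vector_search_key(DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY)  # "chunk_id"
get_databricks_vector_search_key(DATABRICKS_VECTOR_SEARCH_DOC_URI)      # "source_url"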
2 changes: 1 addition & 1 deletion mlflow/tracing/export/inference_table.py
@@ -87,7 +87,7 @@ def _format_spans(self, mlflow_span: LiveSpan) -> Dict[str, Any]:
string instead of a dictionary. Therefore, the attributes are converted from
Dict[str, str(json)] to str(json).
"""
span_dict = mlflow_span.to_dict()
span_dict = mlflow_span.to_dict(dump_events=True)
attributes = span_dict["attributes"]
# deserialize each attribute value and then serialize the whole dictionary
attributes = {k: self._decode_attribute(v) for k, v in attributes.items()}
4 changes: 2 additions & 2 deletions mlflow/tracing/export/mlflow.py
@@ -10,7 +10,7 @@
from mlflow.tracing.display.display_handler import IPythonTraceDisplayHandler
from mlflow.tracing.fluent import TRACE_BUFFER
from mlflow.tracing.trace_manager import InMemoryTraceManager
from mlflow.tracing.utils import maybe_get_evaluation_request_id
from mlflow.tracing.utils import maybe_get_request_id
from mlflow.tracking.client import MlflowClient

_logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ def export(self, root_spans: Sequence[ReadableSpan]):
if eval_request_id := trace.info.tags.get(TraceTagKey.EVAL_REQUEST_ID):
TRACE_BUFFER[eval_request_id] = trace

if not maybe_get_evaluation_request_id():
if not maybe_get_request_id(is_evaluate=True):
# Display the trace in the UI if the trace is not generated from within
# an MLflow model evaluation context
self._display_handler.display_traces([trace])