
[MLflow] Update mlflow langchain pyfunc.load_model to correctly write tags to Tracing Info (#12050)

Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
sunishsheth2009 committed May 19, 2024
1 parent 2c7906a commit b53cd2e
Showing 6 changed files with 75 additions and 7 deletions.
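At a glance: the pyfunc wrapper for the langchain flavor now remembers the local model path, re-reads the MLmodel metadata at predict time, and copies any dependencies_schemas entries onto the tracing prediction context so they are written as tags on the trace info. A minimal sketch of the end-to-end behavior, mirroring the new test in this commit (the model URI here is illustrative):

import mlflow
from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
from mlflow.pyfunc.context import Context

# Load a logged LangChain model through pyfunc; _load_pyfunc now hands the
# local path to _LangChainModelWrapper so the metadata can be re-read later.
loaded_model = mlflow.pyfunc.load_model("models:/my-chain/1")  # illustrative URI

# Predict with an MlflowLangchainTracer attached, as the new test does.
tracer = MlflowLangchainTracer(prediction_context=Context(request_id="mock_request_id"))
response = loaded_model._model_impl._predict_with_callbacks(
    data={"messages": [{"role": "user", "content": "What is MLflow?"}]},
    callback_handlers=[tracer],
)

# Any "dependencies_schemas" recorded in the MLmodel metadata now appear as
# JSON-encoded tags on the resulting trace.
trace = mlflow.get_trace(tracer._request_id)
print(trace.info.tags.get("vector_search_index"))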
26 changes: 24 additions & 2 deletions mlflow/langchain/__init__.py
@@ -15,6 +15,7 @@
 import contextlib
 import functools
 import inspect
+import json
 import logging
 import os
 import warnings
@@ -603,8 +604,9 @@ def _load_model(local_model_path, flavor_conf):


 class _LangChainModelWrapper:
-    def __init__(self, lc_model):
+    def __init__(self, lc_model, model_path=None):
         self.lc_model = lc_model
+        self.model_path = model_path

     def predict(
         self,
@@ -638,6 +640,24 @@ def predict(

         return self._predict_with_callbacks(data, params, callback_handlers=callbacks)

+    def _update_tracing_prediction_context(self, callback_handlers):
+        from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
+
+        if callback_handlers:
+            # TODO: fix this if callback_handlers contains multiple handlers
+            tracer = callback_handlers[0]
+            if isinstance(tracer, MlflowLangchainTracer) and self.model_path:
+                model = Model.load(self.model_path)
+                context = tracer._prediction_context
+                if model.metadata and context:
+                    dependencies_schema = model.metadata.get("dependencies_schemas", {})
+                    context.update(
+                        dependencies_schema={
+                            dependency: json.dumps(schema)
+                            for dependency, schema in dependencies_schema.items()
+                        }
+                    )
+
     @experimental
     def _predict_with_callbacks(
         self,
@@ -659,6 +679,7 @@
         """
         from mlflow.langchain.api_request_parallel_processor import process_api_requests

+        self._update_tracing_prediction_context(callback_handlers)
         messages, return_first_element = self._prepare_predict_messages(data)
         results = process_api_requests(
             lc_model=self.lc_model,
@@ -751,6 +772,7 @@ def _predict_stream_with_callbacks(
             process_stream_request,
         )

+        self._update_tracing_prediction_context(callback_handlers)
         data = self._prepare_predict_stream_messages(data)
         return process_stream_request(
             lc_model=self.lc_model,
@@ -831,7 +853,7 @@ def _load_pyfunc(path: str, model_config: Optional[Dict[str, Any]] = None):
         path: Local filesystem path to the MLflow Model with the ``langchain`` flavor.
     """
     wrapper_cls = _TestLangChainWrapper if _MLFLOW_TESTING.get() else _LangChainModelWrapper
-    return wrapper_cls(_load_model_from_local_fs(path))
+    return wrapper_cls(_load_model_from_local_fs(path), path)


 def _load_model_from_local_fs(local_model_path):
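The heart of the change is the new _update_tracing_prediction_context helper above: it reloads the Model from the stored path and copies metadata["dependencies_schemas"] onto the tracer's prediction context, JSON-encoding each schema so that every tag value is a plain string. A standalone sketch of that transformation, using metadata shaped like the fixture the new test asserts against (the field values are illustrative):

import json

# Illustrative metadata; in the wrapper it comes from Model.load(model_path).metadata.
metadata = {
    "dependencies_schemas": {
        "vector_search_index": [
            {
                "doc_uri": "doc-uri",
                "name": "vector_search_index",
                "primary_key": "primary-key",
                "text_column": "text-column",
            }
        ]
    }
}

# Each schema is serialized with json.dumps before being stored on the context,
# which is what lets it travel as a string-valued trace tag later.
dependencies_schema = {
    dependency: json.dumps(schema)
    for dependency, schema in metadata.get("dependencies_schemas", {}).items()
}
print(dependencies_schema["vector_search_index"])  # a JSON string, not a dict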
11 changes: 10 additions & 1 deletion mlflow/langchain/langchain_tracer.py
@@ -93,8 +93,17 @@ def _start_span(
             )
         else:
             # When parent_run_id is None, this is root component so start trace
+            dependencies_schema = (
+                self._prediction_context.dependencies_schema
+                if self._prediction_context
+                else None
+            )
             span = self._mlflow_client.start_trace(
-                name=span_name, span_type=span_type, inputs=inputs, attributes=attributes
+                name=span_name,
+                span_type=span_type,
+                inputs=inputs,
+                attributes=attributes,
+                tags=dependencies_schema,
             )
             self._request_id = span.request_id
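Note the tracer only attaches these tags when it starts the root component (parent_run_id is None), since the commit writes them to the trace info rather than to individual spans; with no prediction context set, tags stays None and behavior is unchanged. A hypothetical sketch of feeding the context directly, without going through the pyfunc wrapper (the schema value here is made up):

from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
from mlflow.pyfunc.context import Context

# Pre-populate the context instead of letting _update_tracing_prediction_context
# read the schema out of the model metadata.
context = Context(request_id="req-1")
context.update(dependencies_schema={"retriever": '{"primary_key": "id"}'})

# When the root LangChain component starts, _start_span sees parent_run_id is
# None and calls start_trace(..., tags=context.dependencies_schema).
tracer = MlflowLangchainTracer(prediction_context=context)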
9 changes: 9 additions & 0 deletions mlflow/pyfunc/context.py
@@ -16,6 +16,15 @@ class Context:
     request_id: str
     # Whether the current prediction request is as a part of MLflow model evaluation.
     is_evaluate: bool = False
+    # The schema of the dependencies to be added into the tag of trace info.
+    dependencies_schema: Optional[dict] = None
+
+    def update(self, **kwargs):
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+            else:
+                raise AttributeError(f"Context has no attribute named '{key}'")


 @contextlib.contextmanager
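The new update method mutates only attributes that already exist on the dataclass and rejects unknown names, which keeps callers like _update_tracing_prediction_context honest. A quick usage sketch:

from mlflow.pyfunc.context import Context

ctx = Context(request_id="req-1")
ctx.update(is_evaluate=True)                         # existing field: updated in place
ctx.update(dependencies_schema={"retriever": "{}"})  # field added by this commit
try:
    ctx.update(unknown_field=1)
except AttributeError as err:
    print(err)  # Context has no attribute named 'unknown_field'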
32 changes: 28 additions & 4 deletions tests/langchain/test_langchain_model_export.py
@@ -52,13 +52,15 @@
 from mlflow.deployments import PredictionsResponse
 from mlflow.exceptions import MlflowException
 from mlflow.langchain.api_request_parallel_processor import APIRequest
+from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
 from mlflow.langchain.utils import (
     _LC_MIN_VERSION_SUPPORT_CHAT_OPEN_AI,
     IS_PICKLE_SERIALIZATION_RESTRICTED,
 )
 from mlflow.models import Model
 from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex, Resource
 from mlflow.models.signature import ModelSignature, Schema, infer_signature
+from mlflow.pyfunc.context import Context
 from mlflow.tracing.processor.inference_table import _HEADER_REQUEST_ID_KEY
 from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 from mlflow.types.schema import Array, ColSpec, DataType, Object, Property
@@ -72,6 +74,7 @@
 )

 from tests.helper_functions import pyfunc_serve_and_score_model
+from tests.tracing.conftest import clear_singleton as clear_trace_singleton  # noqa: F401
 from tests.tracing.export.test_inference_table_exporter import _REQUEST_ID

 # this kwarg was added in langchain_community 0.0.27, and
@@ -2404,6 +2407,25 @@ def test_save_load_chain_as_code(chain_model_signature):
             }
         ]
     }
+    request_id = "mock_request_id"
+    tracer = MlflowLangchainTracer(prediction_context=Context(request_id))
+    input_example = {"messages": [{"role": "user", "content": "What is MLflow?"}]}
+    response = pyfunc_loaded_model._model_impl._predict_with_callbacks(
+        data=input_example, callback_handlers=[tracer]
+    )
+    assert response["choices"][0]["message"]["content"] == "Databricks"
+    trace = mlflow.get_trace(tracer._request_id)
+    assert trace.info.tags["vector_search_index"] == json.dumps(
+        [
+            {
+                "doc_uri": "doc-uri",
+                "name": "vector_search_index",
+                "other_columns": ["column1", "column2"],
+                "primary_key": "primary-key",
+                "text_column": "text-column",
+            }
+        ]
+    )


 @pytest.mark.skipif(
@@ -2894,7 +2916,9 @@ def retrieve_history(input):
     }


-def test_langchain_model_inject_callback_in_model_serving(monkeypatch, model_path):
+def test_langchain_model_inject_callback_in_model_serving(
+    clear_trace_singleton, monkeypatch, model_path
+):
     # Emulate the model serving environment
     monkeypatch.setenv("IS_IN_DATABRICKS_MODEL_SERVING_ENV", "true")

@@ -2914,10 +2938,11 @@ def test_langchain_model_inject_callback_in_model_serving(

     assert len(_TRACE_BUFFER) == 1
     assert _REQUEST_ID in _TRACE_BUFFER
-    _TRACE_BUFFER.clear()


-def test_langchain_model_not_inject_callback_when_disabled(monkeypatch, model_path):
+def test_langchain_model_not_inject_callback_when_disabled(
+    clear_trace_singleton, monkeypatch, model_path
+):
     # Emulate the model serving environment
     monkeypatch.setenv("IS_IN_DATABRICKS_MODEL_SERVING_ENV", "true")

@@ -2934,7 +2959,6 @@ def test_langchain_model_not_inject_callback_when_disabled(
     from mlflow.tracing.export.inference_table import _TRACE_BUFFER

     assert _TRACE_BUFFER == {}
-    _TRACE_BUFFER.clear()


 @pytest.mark.skipif(
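The two serving tests above can drop their manual _TRACE_BUFFER.clear() calls because they now request the shared clear_singleton fixture (imported into this module as clear_trace_singleton), which the conftest change below extends to also reset the inference-table export buffer. Pytest injects the fixture by parameter name; roughly (the test name here is hypothetical):

def test_something_in_model_serving(clear_trace_singleton, monkeypatch, model_path):
    # clear_trace_singleton has already emptied TRACE_BUFFER and _TRACE_BUFFER,
    # so assertions on buffer contents start from a known-empty state.
    ...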
2 changes: 2 additions & 0 deletions tests/pyfunc/test_context.py
@@ -16,6 +16,8 @@ def set_context(context):
         with set_prediction_context(context):
             time.sleep(0.2 * random.random())
             assert get_prediction_context() == context
+            context.update(is_evaluate=not context.is_evaluate)
+            assert get_prediction_context() == context

     threads = []
     for i in range(10):
2 changes: 2 additions & 0 deletions tests/tracing/conftest.py
@@ -8,6 +8,7 @@
 from mlflow.entities import TraceInfo
 from mlflow.entities.trace_status import TraceStatus
 from mlflow.tracing.display import IPythonTraceDisplayHandler
+from mlflow.tracing.export.inference_table import _TRACE_BUFFER
 from mlflow.tracing.fluent import TRACE_BUFFER
 from mlflow.tracing.provider import _TRACER_PROVIDER_INITIALIZED
 from mlflow.tracing.trace_manager import InMemoryTraceManager
@@ -23,6 +24,7 @@ def clear_singleton():
     InMemoryTraceManager._instance = None
     IPythonTraceDisplayHandler._instance = None
     TRACE_BUFFER.clear()
+    _TRACE_BUFFER.clear()

     # Tracer provider also needs to be reset as it may hold reference to the singleton
     with _TRACER_PROVIDER_SET_ONCE._lock:
