From cad035cb35ca76a8fc3af0ea22d0ad5c0ccd2084 Mon Sep 17 00:00:00 2001
From: Yeesian Ng
Date: Mon, 10 Jun 2024 10:52:28 -0700
Subject: [PATCH] feat: Add enable_tracing to LangchainAgent.

PiperOrigin-RevId: 641955580
---
 setup.py                                       |  5 +-
 ...st_reasoning_engine_templates_langchain.py | 78 +++++++++++++++++++
 .../reasoning_engines/templates/langchain.py  | 45 +++++++++++
 vertexai/reasoning_engines/_utils.py          | 67 ++++++++++++++++
 4 files changed, 194 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 9930b4c380..516080fe35 100644
--- a/setup.py
+++ b/setup.py
@@ -140,6 +140,8 @@

 reasoning_engine_extra_require = [
     "cloudpickle >= 2.2.1, < 4.0",
+    "opentelemetry-sdk < 2",
+    "opentelemetry-exporter-gcp-trace < 2",
     "pydantic >= 2.6.3, < 3",
 ]

@@ -149,9 +151,10 @@
 ]

 langchain_extra_require = [
-    "langchain >= 0.1.16, < 0.2",
+    "langchain >= 0.1.16, < 0.3",
     "langchain-core < 0.2",
     "langchain-google-vertexai < 2",
+    "openinference-instrumentation-langchain >= 0.1.19, < 0.2",
 ]

 langchain_testing_extra_require = list(
diff --git a/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py b/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py
index 1dc3902823..7270722b8e 100644
--- a/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py
+++ b/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py
@@ -23,6 +23,7 @@
 from vertexai.preview import reasoning_engines
 from vertexai.preview.generative_models import grounding
 from vertexai.generative_models import Tool
+from vertexai.reasoning_engines import _utils
 import pytest


@@ -89,6 +90,48 @@ def mock_chatvertexai():
     yield model_mock


+@pytest.fixture
+def cloud_trace_exporter_mock():
+    with mock.patch.object(
+        _utils,
+        "_import_cloud_trace_exporter_or_warn",
+    ) as cloud_trace_exporter_mock:
+        yield cloud_trace_exporter_mock
+
+
+@pytest.fixture
+def tracer_provider_mock():
+    with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock:
+        yield tracer_provider_mock
+
+
+@pytest.fixture
+def simple_span_processor_mock():
+    with mock.patch(
+        "opentelemetry.sdk.trace.export.SimpleSpanProcessor"
+    ) as simple_span_processor_mock:
+        yield simple_span_processor_mock
+
+
+@pytest.fixture
+def langchain_instrumentor_mock():
+    with mock.patch.object(
+        _utils,
+        "_import_openinference_langchain_or_warn",
+    ) as langchain_instrumentor_mock:
+        yield langchain_instrumentor_mock
+
+
+@pytest.fixture
+def langchain_instrumentor_none_mock():
+    with mock.patch.object(
+        _utils,
+        "_import_openinference_langchain_or_warn",
+    ) as langchain_instrumentor_mock:
+        langchain_instrumentor_mock.return_value = None
+        yield langchain_instrumentor_mock
+
+
 @pytest.mark.usefixtures("google_auth_mock")
 class TestLangchainAgent:
     def setup_method(self):
@@ -175,6 +218,41 @@ def test_query(self, langchain_dump_mock):
             [mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
         )

+    @pytest.mark.usefixtures("caplog")
+    def test_enable_tracing(
+        self,
+        caplog,
+        cloud_trace_exporter_mock,
+        tracer_provider_mock,
+        simple_span_processor_mock,
+        langchain_instrumentor_mock,
+    ):
+        agent = reasoning_engines.LangchainAgent(
+            model=_TEST_MODEL,
+            prompt=self.prompt,
+            output_parser=self.output_parser,
+            enable_tracing=True,
+        )
+        assert agent._instrumentor is None
+        agent.set_up()
+        assert agent._instrumentor is not None
+        assert (
+            "enable_tracing=True but proceeding with tracing disabled"
+            not in caplog.text
+        )
+
+    @pytest.mark.usefixtures("caplog")
+    def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock):
+        agent = reasoning_engines.LangchainAgent(
+            model=_TEST_MODEL,
+            prompt=self.prompt,
+            output_parser=self.output_parser,
+            enable_tracing=True,
+        )
+        assert agent._instrumentor is None
+        agent.set_up()
+        assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text
+

 class TestConvertToolsOrRaise:
     def test_convert_tools_or_raise(self, vertexai_init_mock):
diff --git a/vertexai/preview/reasoning_engines/templates/langchain.py b/vertexai/preview/reasoning_engines/templates/langchain.py
index 10f8969008..553fd0c7f8 100644
--- a/vertexai/preview/reasoning_engines/templates/langchain.py
+++ b/vertexai/preview/reasoning_engines/templates/langchain.py
@@ -236,6 +236,7 @@ def __init__(
         runnable_kwargs: Optional[Mapping[str, Any]] = None,
         model_builder: Optional[Callable] = None,
         runnable_builder: Optional[Callable] = None,
+        enable_tracing: bool = False,
     ):
         """Initializes the LangchainAgent.

@@ -349,6 +350,9 @@
                 for customizing the orchestration logic of the Agent based
                 on the model returned by `model_builder` and the rest of
                 the input arguments.
+            enable_tracing (bool):
+                Optional. Whether to enable tracing in Cloud Trace. Defaults to
+                False.

         Raises:
             TypeError: If there is an invalid tool (e.g. function with an input
@@ -376,6 +380,8 @@
         self._model_builder = model_builder
         self._runnable = None
         self._runnable_builder = runnable_builder
+        self._instrumentor = None
+        self._enable_tracing = enable_tracing

     def set_up(self):
         """Sets up the agent for execution of queries at runtime.
@@ -387,6 +393,44 @@
         the ReasoningEngine service for deployment, as it initializes clients
         that can not be serialized.
         """
+        if self._enable_tracing:
+            from vertexai.reasoning_engines import _utils
+
+            cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
+            openinference_langchain = _utils._import_openinference_langchain_or_warn()
+            opentelemetry = _utils._import_opentelemetry_or_warn()
+            opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
+            if all(
+                (
+                    cloud_trace_exporter,
+                    openinference_langchain,
+                    opentelemetry,
+                    opentelemetry_sdk_trace,
+                )
+            ):
+                tracer_provider = opentelemetry.trace.get_tracer_provider()
+                if tracer_provider and _utils._is_noop_tracer_provider(tracer_provider):
+                    # Set a trace provider if it has not been set.
+                    span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
+                        project_id=self._project,
+                    )
+                    span_processor = opentelemetry_sdk_trace.export.SimpleSpanProcessor(
+                        span_exporter=span_exporter,
+                    )
+                    tracer_provider = opentelemetry_sdk_trace.TracerProvider(
+                        active_span_processor=span_processor,
+                    )
+                    opentelemetry.trace.set_tracer_provider(tracer_provider)
+                self._instrumentor = openinference_langchain.LangChainInstrumentor()
+                self._instrumentor.instrument()
+            else:
+                from google.cloud.aiplatform import base
+
+                _LOGGER = base.Logger(__name__)
+                _LOGGER.warning(
+                    "enable_tracing=True but proceeding with tracing disabled "
+                    "because not all packages for tracing have been installed"
+                )
         model_builder = self._model_builder or _default_model_builder
         self._model = model_builder(
             model_name=self._model_name,
@@ -422,6 +466,7 @@ def clone(self) -> "LangchainAgent":
             runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
             model_builder=self._model_builder,
             runnable_builder=self._runnable_builder,
+            enable_tracing=self._enable_tracing,
         )

     def query(
diff --git a/vertexai/reasoning_engines/_utils.py b/vertexai/reasoning_engines/_utils.py
index 70ae8e4cc4..84e526f9d9 100644
--- a/vertexai/reasoning_engines/_utils.py
+++ b/vertexai/reasoning_engines/_utils.py
@@ -21,6 +21,7 @@

 import proto

+from google.cloud.aiplatform import base
 from google.protobuf import struct_pb2
 from google.protobuf import json_format

@@ -36,6 +37,8 @@

 JsonDict = Dict[str, Any]

+_LOGGER = base.Logger(__name__)
+

 def to_proto(
     obj: Union[JsonDict, proto.Message],
@@ -195,6 +198,14 @@ def generate_schema(
     return schema


+def _is_noop_tracer_provider(tracer_provider) -> bool:
+    """Returns True if the tracer_provider is Proxy or NoOp."""
+    opentelemetry = _import_opentelemetry_or_warn()
+    ProxyTracerProvider = opentelemetry.trace.ProxyTracerProvider
+    NoOpTracerProvider = opentelemetry.trace.NoOpTracerProvider
+    return isinstance(tracer_provider, (NoOpTracerProvider, ProxyTracerProvider))
+
+
 def _import_cloud_storage_or_raise() -> types.ModuleType:
     """Tries to import the Cloud Storage module."""
     try:
@@ -233,3 +244,59 @@ def _import_pydantic_or_raise() -> types.ModuleType:
             "'pip install google-cloud-aiplatform[reasoningengine]'."
         ) from e
     return pydantic
+
+
+def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
+    """Tries to import the opentelemetry module."""
+    try:
+        import opentelemetry  # noqa:F401
+
+        return opentelemetry
+    except ImportError:
+        _LOGGER.warning(
+            "opentelemetry-sdk is not installed. Please call "
+            "'pip install google-cloud-aiplatform[reasoningengine]'."
+        )
+    return None
+
+
+def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
+    """Tries to import the opentelemetry.sdk.trace module."""
+    try:
+        import opentelemetry.sdk.trace  # noqa:F401
+
+        return opentelemetry.sdk.trace
+    except ImportError:
+        _LOGGER.warning(
+            "opentelemetry-sdk is not installed. Please call "
+            "'pip install google-cloud-aiplatform[reasoningengine]'."
+        )
+    return None
+
+
+def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]:
+    """Tries to import the opentelemetry.exporter.cloud_trace module."""
+    try:
+        import opentelemetry.exporter.cloud_trace  # noqa:F401
+
+        return opentelemetry.exporter.cloud_trace
+    except ImportError:
+        _LOGGER.warning(
+            "opentelemetry-exporter-gcp-trace is not installed. Please "
+            "call 'pip install google-cloud-aiplatform[langchain]'."
+ ) + return None + + +def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]: + """Tries to import the openinference.instrumentation.langchain module.""" + try: + import openinference.instrumentation.langchain # noqa:F401 + + return openinference.instrumentation.langchain + except ImportError: + _LOGGER.warning( + "openinference-instrumentation-langchain is not installed. Please " + "call 'pip install google-cloud-aiplatform[langchain]'." + ) + return None