Skip to content

Commit

Permalink
feat: Add enable_tracing to LangchainAgent.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 641955580
  • Loading branch information
Yeesian Ng authored and Copybara-Service committed Jun 10, 2024
1 parent a78a35e commit cad035c
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 1 deletion.
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@

reasoning_engine_extra_require = [
"cloudpickle >= 2.2.1, < 4.0",
"opentelemetry-sdk < 2",
"opentelemetry-exporter-gcp-trace < 2",
"pydantic >= 2.6.3, < 3",
]

Expand All @@ -149,9 +151,10 @@
]

langchain_extra_require = [
"langchain >= 0.1.16, < 0.2",
"langchain >= 0.1.16, < 0.3",
"langchain-core < 0.2",
"langchain-google-vertexai < 2",
"openinference-instrumentation-langchain >= 0.1.19, < 0.2",
]

langchain_testing_extra_require = list(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from vertexai.preview import reasoning_engines
from vertexai.preview.generative_models import grounding
from vertexai.generative_models import Tool
from vertexai.reasoning_engines import _utils
import pytest


Expand Down Expand Up @@ -89,6 +90,48 @@ def mock_chatvertexai():
yield model_mock


@pytest.fixture
def cloud_trace_exporter_mock():
with mock.patch.object(
_utils,
"_import_cloud_trace_exporter_or_warn",
) as cloud_trace_exporter_mock:
yield cloud_trace_exporter_mock


@pytest.fixture
def tracer_provider_mock():
with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock:
yield tracer_provider_mock


@pytest.fixture
def simple_span_processor_mock():
with mock.patch(
"opentelemetry.sdk.trace.export.SimpleSpanProcessor"
) as simple_span_processor_mock:
yield simple_span_processor_mock


@pytest.fixture
def langchain_instrumentor_mock():
with mock.patch.object(
_utils,
"_import_openinference_langchain_or_warn",
) as langchain_instrumentor_mock:
yield langchain_instrumentor_mock


@pytest.fixture
def langchain_instrumentor_none_mock():
with mock.patch.object(
_utils,
"_import_openinference_langchain_or_warn",
) as langchain_instrumentor_mock:
langchain_instrumentor_mock.return_value = None
yield langchain_instrumentor_mock


@pytest.mark.usefixtures("google_auth_mock")
class TestLangchainAgent:
def setup_method(self):
Expand Down Expand Up @@ -175,6 +218,41 @@ def test_query(self, langchain_dump_mock):
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
)

@pytest.mark.usefixtures("caplog")
def test_enable_tracing(
self,
caplog,
cloud_trace_exporter_mock,
tracer_provider_mock,
simple_span_processor_mock,
langchain_instrumentor_mock,
):
agent = reasoning_engines.LangchainAgent(
model=_TEST_MODEL,
prompt=self.prompt,
output_parser=self.output_parser,
enable_tracing=True,
)
assert agent._instrumentor is None
agent.set_up()
assert agent._instrumentor is not None
assert (
"enable_tracing=True but proceeding with tracing disabled"
not in caplog.text
)

@pytest.mark.usefixtures("caplog")
def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock):
agent = reasoning_engines.LangchainAgent(
model=_TEST_MODEL,
prompt=self.prompt,
output_parser=self.output_parser,
enable_tracing=True,
)
assert agent._instrumentor is None
agent.set_up()
assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text


class TestConvertToolsOrRaise:
def test_convert_tools_or_raise(self, vertexai_init_mock):
Expand Down
45 changes: 45 additions & 0 deletions vertexai/preview/reasoning_engines/templates/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def __init__(
runnable_kwargs: Optional[Mapping[str, Any]] = None,
model_builder: Optional[Callable] = None,
runnable_builder: Optional[Callable] = None,
enable_tracing: bool = False,
):
"""Initializes the LangchainAgent.
Expand Down Expand Up @@ -349,6 +350,9 @@ def __init__(
for customizing the orchestration logic of the Agent based on
the model returned by `model_builder` and the rest of the input
arguments.
enable_tracing (bool):
Optional. Whether to enable tracing in Cloud Trace. Defaults to
False.
Raises:
TypeError: If there is an invalid tool (e.g. function with an input
Expand Down Expand Up @@ -376,6 +380,8 @@ def __init__(
self._model_builder = model_builder
self._runnable = None
self._runnable_builder = runnable_builder
self._instrumentor = None
self._enable_tracing = enable_tracing

def set_up(self):
"""Sets up the agent for execution of queries at runtime.
Expand All @@ -387,6 +393,44 @@ def set_up(self):
the ReasoningEngine service for deployment, as it initializes clients
that can not be serialized.
"""
if self._enable_tracing:
from vertexai.reasoning_engines import _utils

cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
openinference_langchain = _utils._import_openinference_langchain_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
if all(
(
cloud_trace_exporter,
openinference_langchain,
opentelemetry,
opentelemetry_sdk_trace,
)
):
tracer_provider = opentelemetry.trace.get_tracer_provider()
if tracer_provider and _utils._is_noop_tracer_provider(tracer_provider):
# Set a trace provider if it has not been set.
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=self._project,
)
span_processor = opentelemetry_sdk_trace.export.SimpleSpanProcessor(
span_exporter=span_exporter,
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
active_span_processor=span_processor,
)
opentelemetry.trace.set_tracer_provider(tracer_provider)
self._instrumentor = openinference_langchain.LangChainInstrumentor()
self._instrumentor.instrument()
else:
from google.cloud.aiplatform import base

_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"enable_tracing=True but proceeding with tracing disabled "
"because not all packages for tracing have been installed"
)
model_builder = self._model_builder or _default_model_builder
self._model = model_builder(
model_name=self._model_name,
Expand Down Expand Up @@ -422,6 +466,7 @@ def clone(self) -> "LangchainAgent":
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
model_builder=self._model_builder,
runnable_builder=self._runnable_builder,
enable_tracing=self._enable_tracing,
)

def query(
Expand Down
67 changes: 67 additions & 0 deletions vertexai/reasoning_engines/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import proto

from google.cloud.aiplatform import base
from google.protobuf import struct_pb2
from google.protobuf import json_format

Expand All @@ -36,6 +37,8 @@

JsonDict = Dict[str, Any]

_LOGGER = base.Logger(__name__)


def to_proto(
obj: Union[JsonDict, proto.Message],
Expand Down Expand Up @@ -195,6 +198,14 @@ def generate_schema(
return schema


def _is_noop_tracer_provider(tracer_provider) -> bool:
"""Returns True if the tracer_provider is Proxy or NoOp."""
opentelemetry = _import_opentelemetry_or_warn()
ProxyTracerProvider = opentelemetry.trace.ProxyTracerProvider
NoOpTracerProvider = opentelemetry.trace.NoOpTracerProvider
return isinstance(tracer_provider, (NoOpTracerProvider, ProxyTracerProvider))


def _import_cloud_storage_or_raise() -> types.ModuleType:
"""Tries to import the Cloud Storage module."""
try:
Expand Down Expand Up @@ -233,3 +244,59 @@ def _import_pydantic_or_raise() -> types.ModuleType:
"'pip install google-cloud-aiplatform[reasoningengine]'."
) from e
return pydantic


def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the opentelemetry module."""
try:
import opentelemetry # noqa:F401

return opentelemetry
except ImportError:
_LOGGER.warning(
"opentelemetry-sdk is not installed. Please call "
"'pip install google-cloud-aiplatform[reasoningengine]'."
)
return None


def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the opentelemetry.sdk.trace module."""
try:
import opentelemetry.sdk.trace # noqa:F401

return opentelemetry.sdk.trace
except ImportError:
_LOGGER.warning(
"opentelemetry-sdk is not installed. Please call "
"'pip install google-cloud-aiplatform[reasoningengine]'."
)
return None


def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the opentelemetry.exporter.cloud_trace module."""
try:
import opentelemetry.exporter.cloud_trace # noqa:F401

return opentelemetry.exporter.cloud_trace
except ImportError:
_LOGGER.warning(
"opentelemetry-exporter-gcp-trace is not installed. Please "
"call 'pip install google-cloud-aiplatform[langchain]'."
)
return None


def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the openinference.instrumentation.langchain module."""
try:
import openinference.instrumentation.langchain # noqa:F401

return openinference.instrumentation.langchain
except ImportError:
_LOGGER.warning(
"openinference-instrumentation-langchain is not installed. Please "
"call 'pip install google-cloud-aiplatform[langchain]'."
)
return None

0 comments on commit cad035c

Please sign in to comment.