diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index 079a6f34a..6f527aa7b 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -1,5 +1,6 @@ import contextvars import inspect +from opentelemetry import context as otel_context from typing import ( Any, AsyncIterable, @@ -28,6 +29,7 @@ set_tracer, set_tracer_context, ) +from guardrails.utils.telemetry_utils import wrap_with_otel_context class AsyncGuard(Guard): @@ -191,8 +193,12 @@ async def __call( return result guard_context = contextvars.Context() + # get the current otel context and wrap the subsequent call to preserve otel context if guard call is being called by another + # framework upstream + current_otel_context = otel_context.get_current() + wrapped__call = wrap_with_otel_context(current_otel_context, __call) return await guard_context.run( - __call, + wrapped__call, self, llm_api, prompt_params, @@ -416,10 +422,13 @@ async def __parse( *args, **kwargs, ) - guard_context = contextvars.Context() + # get the current otel context and wrap the subsequent call to preserve otel context if guard call is being called by another + # framework upstream + current_otel_context = otel_context.get_current() + wrapped__parse = wrap_with_otel_context(current_otel_context, __parse) return await guard_context.run( - __parse, + wrapped__parse, self, llm_output, metadata, diff --git a/guardrails/guard.py b/guardrails/guard.py index 969f647fe..63767e1e3 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -34,6 +34,7 @@ from guardrails_api_client.types import UNSET from langchain_core.messages import BaseMessage from langchain_core.runnables import Runnable, RunnableConfig +from opentelemetry import context as otel_context from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION from typing_extensions import deprecated # type: ignore @@ -71,6 +72,7 @@ from guardrails.utils.hub_telemetry_utils import HubTelemetry from guardrails.utils.llm_response import LLMResponse from guardrails.utils.reask_utils import FieldReAsk +from guardrails.utils.telemetry_utils import wrap_with_otel_context from guardrails.utils.validator_utils import get_validator from guardrails.validator_base import FailResult, Validator @@ -697,8 +699,14 @@ def __call( ) guard_context = contextvars.Context() + + # get the current otel context and wrap the subsequent call to preserve otel context if guard call is being called be another + # framework upstream + current_otel_context = otel_context.get_current() + wrapped__call = wrap_with_otel_context(current_otel_context, __call) + return guard_context.run( - __call, + wrapped__call, self, llm_api, prompt_params, @@ -1059,6 +1067,10 @@ def __parse( ) guard_context = contextvars.Context() + # get the current otel context and wrap the subsequent call to preserve otel context if guard call is being called be another + # framework upstream + current_otel_context = otel_context.get_current() + wrapped__parse = wrap_with_otel_context(current_otel_context, __parse) return guard_context.run( __parse, self, diff --git a/guardrails/utils/telemetry_utils.py b/guardrails/utils/telemetry_utils.py index 35c453491..a14f8fe3b 100644 --- a/guardrails/utils/telemetry_utils.py +++ b/guardrails/utils/telemetry_utils.py @@ -1,7 +1,7 @@ import sys from functools import wraps from operator import attrgetter -from typing import Any, List, Optional, Union +from typing import Any, Callable, List, Optional, Union from opentelemetry import context from opentelemetry.context import Context @@ -246,6 +246,31 @@ async def to_trace_or_not_to_trace(*args, **kwargs): return trace_wrapper +def wrap_with_otel_context(outer_scope_otel_context: Context, func: Callable[..., Any]) -> Callable[..., Any]: + """ + This function is designed to ensure that a given OpenTelemetry context is applied when executing a specified function. + It is particularly useful for preserving the trace context when a guardrails is executed in a different execution flow or + when integrating with other frameworks. + + Args: + outer_scope_otel_context (Context): The OpenTelemetry context to apply when executing the function. + func (Callable[..., Any]): The function to be executed within the given OpenTelemetry context. + + Returns: + Callable[..., Any]: A wrapped version of 'func' that, when called, executes with 'outer_scope_otel_context' applied. + """ + def wrapped_func(*args: Any, **kwargs: Any) -> Any: + # Attach the specified OpenTelemetry context before executing 'func' + token = context.attach(outer_scope_otel_context) + try: + # Execute 'func' within the attached context + return func(*args, **kwargs) + finally: + # Ensure the context is detached after execution to maintain correct context management + context.detach(token) + return wrapped_func + + def default_otel_collector_tracer(resource_name: str = "guardsrails"): """This is the standard otel tracer set to talk to a grpc open telemetry collector running on port 4317."""