diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index d69be04d0..0b824fa70 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -1,6 +1,7 @@ from builtins import id as object_id import contextvars import inspect +from opentelemetry import context as otel_context from typing import ( Any, AsyncIterable, @@ -38,6 +39,7 @@ ) from guardrails.types.pydantic import ModelOrListOfModels from guardrails.types.validator import UseManyValidatorSpec, UseValidatorSpec +from guardrails.utils.telemetry_utils import wrap_with_otel_context from guardrails.utils.validator_utils import verify_metadata_requirements from guardrails.validator_base import Validator @@ -320,8 +322,13 @@ async def __exec( return result # type: ignore guard_context = contextvars.Context() + # get the current otel context and wrap the subsequent call + # to preserve otel context if guard call is being called by another + # framework upstream + current_otel_context = otel_context.get_current() + wrapped__exec = wrap_with_otel_context(current_otel_context, __exec) return await guard_context.run( - __exec, + wrapped__exec, self, llm_api=llm_api, llm_output=llm_output, diff --git a/guardrails/guard.py b/guardrails/guard.py index d3c494c31..e37e130d2 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -26,6 +26,7 @@ SimpleTypes, ValidationOutcome as IValidationOutcome, ) +from opentelemetry import context as otel_context from pydantic import field_validator from pydantic.config import ConfigDict @@ -67,6 +68,7 @@ from guardrails.utils.naming_utils import random_id from guardrails.utils.api_utils import extract_serializeable_metadata from guardrails.utils.hub_telemetry_utils import HubTelemetry +from guardrails.utils.telemetry_utils import wrap_with_otel_context from guardrails.utils.validator_utils import ( get_validator, parse_validator_reference, @@ -759,8 +761,15 @@ def __exec( ) guard_context = contextvars.Context() + + # get the current otel context and wrap the subsequent call + # to preserve otel context if guard call is being called be another + # framework upstream + current_otel_context = otel_context.get_current() + wrapped__exec = wrap_with_otel_context(current_otel_context, __exec) + return guard_context.run( - __exec, + wrapped__exec, self, llm_api=llm_api, llm_output=llm_output, diff --git a/guardrails/utils/telemetry_utils.py b/guardrails/utils/telemetry_utils.py index 1972ab86e..257980347 100644 --- a/guardrails/utils/telemetry_utils.py +++ b/guardrails/utils/telemetry_utils.py @@ -1,7 +1,7 @@ import sys from functools import wraps from operator import attrgetter -from typing import Any, List, Optional, Union +from typing import Any, Callable, List, Optional, Union from opentelemetry import context from opentelemetry.context import Context @@ -237,6 +237,39 @@ async def to_trace_or_not_to_trace(*args, **kwargs): return trace_wrapper +def wrap_with_otel_context( + outer_scope_otel_context: Context, func: Callable[..., Any] +) -> Callable[..., Any]: + """This function is designed to ensure that a given OpenTelemetry context + is applied when executing a specified function. It is particularly useful + for preserving the trace context when a guardrails is executed in a + different execution flow or when integrating with other frameworks. + + Args: + outer_scope_otel_context (Context): The OpenTelemetry context to apply + when executing the function. + func (Callable[..., Any]): The function to be executed within + the given OpenTelemetry context. + + Returns: + Callable[..., Any]: A wrapped version of 'func' that, when called, + executes with 'outer_scope_otel_context' applied. + """ + + def wrapped_func(*args: Any, **kwargs: Any) -> Any: + # Attach the specified OpenTelemetry context before executing 'func' + token = context.attach(outer_scope_otel_context) + try: + # Execute 'func' within the attached context + return func(*args, **kwargs) + finally: + # Ensure the context is detached after execution + # to maintain correct context management + context.detach(token) + + return wrapped_func + + def default_otel_collector_tracer(resource_name: str = "guardsrails"): """This is the standard otel tracer set to talk to a grpc open telemetry collector running on port 4317."""