From bf54efd1b8466351c91c5d5e08efca2331d6ef75 Mon Sep 17 00:00:00 2001 From: Din Date: Wed, 5 Nov 2025 11:15:01 +0000 Subject: [PATCH] fix decorator typing, bump dev dep on litellm --- pyproject.toml | 2 +- .../opentelemetry_lib/decorators/__init__.py | 12 +-- src/lmnr/sdk/decorators.py | 77 +++++++++++++++---- 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c9d7c223..2fd42bf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,7 +132,7 @@ dev = [ "langgraph>=0.4.8", "langchain-core>=0.3.75", "langchain>=0.3.27", - "litellm>=1.77.0", + "litellm>=1.79.1", "groq>=0.30.0", "anthropic[bedrock]>=0.60.0", "langchain-openai>=0.3.32", diff --git a/src/lmnr/opentelemetry_lib/decorators/__init__.py b/src/lmnr/opentelemetry_lib/decorators/__init__.py index 9b96f529..fb7da05e 100644 --- a/src/lmnr/opentelemetry_lib/decorators/__init__.py +++ b/src/lmnr/opentelemetry_lib/decorators/__init__.py @@ -2,7 +2,7 @@ import pydantic import orjson import types -from typing import Any, AsyncGenerator, Callable, Generator, Literal +from typing import Any, AsyncGenerator, Callable, Generator, Literal, TypeVar from opentelemetry import context as context_api from opentelemetry.trace import Span, Status, StatusCode @@ -28,6 +28,8 @@ logger = get_default_logger(__name__) +F = TypeVar("F", bound=Callable[..., Any]) + DEFAULT_PLACEHOLDER = {} @@ -179,8 +181,8 @@ def observe_base( input_formatter: Callable[..., str] | None = None, output_formatter: Callable[..., str] | None = None, preserve_global_context: bool = False, -): - def decorate(fn): +) -> Callable[[F], F]: + def decorate(fn: F) -> F: @wraps(fn) def wrap(*args, **kwargs): if not TracerWrapper.verify_initialized(): @@ -257,8 +259,8 @@ def async_observe_base( input_formatter: Callable[..., str] | None = None, output_formatter: Callable[..., str] | None = None, preserve_global_context: bool = False, -): - def decorate(fn): +) -> Callable[[F], F]: + def decorate(fn: F) -> F: @wraps(fn) async def wrap(*args, **kwargs): if not TracerWrapper.verify_initialized(): diff --git a/src/lmnr/sdk/decorators.py b/src/lmnr/sdk/decorators.py index d8fb1390..ae700a0e 100644 --- a/src/lmnr/sdk/decorators.py +++ b/src/lmnr/sdk/decorators.py @@ -5,7 +5,7 @@ ) from opentelemetry.trace import INVALID_SPAN, get_current_span -from typing import Any, Callable, Literal, TypeVar, cast +from typing import Any, Callable, Coroutine, Literal, TypeVar, overload from typing_extensions import ParamSpec from lmnr.opentelemetry_lib.tracing.attributes import SESSION_ID @@ -19,6 +19,8 @@ R = TypeVar("R") +# Overload for synchronous functions +@overload def observe( *, name: str | None = None, @@ -28,12 +30,52 @@ def observe( ignore_output: bool = False, span_type: Literal["DEFAULT", "LLM", "TOOL"] = "DEFAULT", ignore_inputs: list[str] | None = None, - input_formatter: Callable[P, str] | None = None, - output_formatter: Callable[[R], str] | None = None, + input_formatter: Callable[..., str] | None = None, + output_formatter: Callable[..., str] | None = None, metadata: dict[str, Any] | None = None, tags: list[str] | None = None, preserve_global_context: bool = False, -) -> Callable[[Callable[P, R]], Callable[P, R]]: +) -> Callable[[Callable[P, R]], Callable[P, R]]: ... + + +# Overload for asynchronous functions +@overload +def observe( + *, + name: str | None = None, + session_id: str | None = None, + user_id: str | None = None, + ignore_input: bool = False, + ignore_output: bool = False, + span_type: Literal["DEFAULT", "LLM", "TOOL"] = "DEFAULT", + ignore_inputs: list[str] | None = None, + input_formatter: Callable[..., str] | None = None, + output_formatter: Callable[..., str] | None = None, + metadata: dict[str, Any] | None = None, + tags: list[str] | None = None, + preserve_global_context: bool = False, +) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] +]: ... + + +# Implementation +def observe( + *, + name: str | None = None, + session_id: str | None = None, + user_id: str | None = None, + ignore_input: bool = False, + ignore_output: bool = False, + span_type: Literal["DEFAULT", "LLM", "TOOL"] = "DEFAULT", + ignore_inputs: list[str] | None = None, + input_formatter: Callable[..., str] | None = None, + output_formatter: Callable[..., str] | None = None, + metadata: dict[str, Any] | None = None, + tags: list[str] | None = None, + preserve_global_context: bool = False, +): + # Return type is determined by overloads above """The main decorator entrypoint for Laminar. This is used to wrap functions and methods to create spans. @@ -57,14 +99,15 @@ def foo(a, b, `sensitive_data`), and you want to ignore the\ `sensitive_data` argument, you can pass ["sensitive_data"] to\ this argument. Defaults to None. input_formatter (Callable[P, str] | None, optional): A custom function\ - to format the input of the wrapped function. All function arguments\ - are passed to this function. Must return a string. Ignored if\ + to format the input of the wrapped function. This function should\ + accept the same parameters as the wrapped function and return a string.\ + All function arguments are passed to this function. Ignored if\ `ignore_input` is True. Does not respect `ignore_inputs` argument. Defaults to None. output_formatter (Callable[[R], str] | None, optional): A custom function\ - to format the output of the wrapped function. The output is passed\ - to this function. Must return a string. Ignored if `ignore_output` - is True. Does not respect `ignore_inputs` argument. + to format the output of the wrapped function. This function should\ + accept a single parameter (the return value of the wrapped function)\ + and return a string. Ignored if `ignore_output` is True.\ Defaults to None. metadata (dict[str, Any] | None, optional): Metadata to associate with\ the trace. Must be JSON serializable. Defaults to None. @@ -81,7 +124,9 @@ def foo(a, b, `sensitive_data`), and you want to ignore the\ R: Returns the result of the wrapped function """ - def decorator(func: Callable) -> Callable: + def decorator( + func: Callable[P, R] | Callable[P, Coroutine[Any, Any, R]], + ) -> Callable[P, R] | Callable[P, Coroutine[Any, Any, R]]: current_span = get_current_span() if current_span != INVALID_SPAN: if session_id is not None: @@ -127,8 +172,8 @@ def decorator(func: Callable) -> Callable: " is ignored because `ignore_output` is True. Specify only one of" " `ignore_output` or `output_formatter`." ) - result = ( - async_observe_base( + if is_async(func): + return async_observe_base( name=name, ignore_input=ignore_input, ignore_output=ignore_output, @@ -139,8 +184,8 @@ def decorator(func: Callable) -> Callable: association_properties=association_properties, preserve_global_context=preserve_global_context, )(func) - if is_async(func) - else observe_base( + else: + return observe_base( name=name, ignore_input=ignore_input, ignore_output=ignore_output, @@ -151,7 +196,5 @@ def decorator(func: Callable) -> Callable: association_properties=association_properties, preserve_global_context=preserve_global_context, )(func) - ) - return result - return cast(Callable, decorator) + return decorator