Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ dev = [
"langgraph>=0.4.8",
"langchain-core>=0.3.75",
"langchain>=0.3.27",
"litellm>=1.77.0",
"litellm>=1.79.1",
"groq>=0.30.0",
"anthropic[bedrock]>=0.60.0",
"langchain-openai>=0.3.32",
Expand Down
12 changes: 7 additions & 5 deletions src/lmnr/opentelemetry_lib/decorators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pydantic
import orjson
import types
from typing import Any, AsyncGenerator, Callable, Generator, Literal
from typing import Any, AsyncGenerator, Callable, Generator, Literal, TypeVar

from opentelemetry import context as context_api
from opentelemetry.trace import Span, Status, StatusCode
Expand All @@ -28,6 +28,8 @@

logger = get_default_logger(__name__)

F = TypeVar("F", bound=Callable[..., Any])

DEFAULT_PLACEHOLDER = {}


Expand Down Expand Up @@ -179,8 +181,8 @@ def observe_base(
input_formatter: Callable[..., str] | None = None,
output_formatter: Callable[..., str] | None = None,
preserve_global_context: bool = False,
):
def decorate(fn):
) -> Callable[[F], F]:
def decorate(fn: F) -> F:
@wraps(fn)
def wrap(*args, **kwargs):
if not TracerWrapper.verify_initialized():
Expand Down Expand Up @@ -257,8 +259,8 @@ def async_observe_base(
input_formatter: Callable[..., str] | None = None,
output_formatter: Callable[..., str] | None = None,
preserve_global_context: bool = False,
):
def decorate(fn):
) -> Callable[[F], F]:
def decorate(fn: F) -> F:
@wraps(fn)
async def wrap(*args, **kwargs):
if not TracerWrapper.verify_initialized():
Expand Down
77 changes: 60 additions & 17 deletions src/lmnr/sdk/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
)
from opentelemetry.trace import INVALID_SPAN, get_current_span

from typing import Any, Callable, Literal, TypeVar, cast
from typing import Any, Callable, Coroutine, Literal, TypeVar, overload
from typing_extensions import ParamSpec

from lmnr.opentelemetry_lib.tracing.attributes import SESSION_ID
Expand All @@ -19,6 +19,8 @@
R = TypeVar("R")


# Overload for synchronous functions
@overload
def observe(
*,
name: str | None = None,
Expand All @@ -28,12 +30,52 @@ def observe(
ignore_output: bool = False,
span_type: Literal["DEFAULT", "LLM", "TOOL"] = "DEFAULT",
ignore_inputs: list[str] | None = None,
input_formatter: Callable[P, str] | None = None,
output_formatter: Callable[[R], str] | None = None,
input_formatter: Callable[..., str] | None = None,
output_formatter: Callable[..., str] | None = None,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
preserve_global_context: bool = False,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...


# Overload for asynchronous functions
@overload
def observe(
*,
name: str | None = None,
session_id: str | None = None,
user_id: str | None = None,
ignore_input: bool = False,
ignore_output: bool = False,
span_type: Literal["DEFAULT", "LLM", "TOOL"] = "DEFAULT",
ignore_inputs: list[str] | None = None,
input_formatter: Callable[..., str] | None = None,
output_formatter: Callable[..., str] | None = None,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
preserve_global_context: bool = False,
) -> Callable[
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
]: ...


# Implementation
def observe(
*,
name: str | None = None,
session_id: str | None = None,
user_id: str | None = None,
ignore_input: bool = False,
ignore_output: bool = False,
span_type: Literal["DEFAULT", "LLM", "TOOL"] = "DEFAULT",
ignore_inputs: list[str] | None = None,
input_formatter: Callable[..., str] | None = None,
output_formatter: Callable[..., str] | None = None,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
preserve_global_context: bool = False,
):
# Return type is determined by overloads above
"""The main decorator entrypoint for Laminar. This is used to wrap
functions and methods to create spans.

Expand All @@ -57,14 +99,15 @@ def foo(a, b, `sensitive_data`), and you want to ignore the\
`sensitive_data` argument, you can pass ["sensitive_data"] to\
this argument. Defaults to None.
input_formatter (Callable[P, str] | None, optional): A custom function\
to format the input of the wrapped function. All function arguments\
are passed to this function. Must return a string. Ignored if\
to format the input of the wrapped function. This function should\
accept the same parameters as the wrapped function and return a string.\
All function arguments are passed to this function. Ignored if\
`ignore_input` is True. Does not respect `ignore_inputs` argument.
Defaults to None.
output_formatter (Callable[[R], str] | None, optional): A custom function\
to format the output of the wrapped function. The output is passed\
to this function. Must return a string. Ignored if `ignore_output`
is True. Does not respect `ignore_inputs` argument.
to format the output of the wrapped function. This function should\
accept a single parameter (the return value of the wrapped function)\
and return a string. Ignored if `ignore_output` is True.\
Defaults to None.
metadata (dict[str, Any] | None, optional): Metadata to associate with\
the trace. Must be JSON serializable. Defaults to None.
Expand All @@ -81,7 +124,9 @@ def foo(a, b, `sensitive_data`), and you want to ignore the\
R: Returns the result of the wrapped function
"""

def decorator(func: Callable) -> Callable:
def decorator(
func: Callable[P, R] | Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, R] | Callable[P, Coroutine[Any, Any, R]]:
current_span = get_current_span()
if current_span != INVALID_SPAN:
if session_id is not None:
Expand Down Expand Up @@ -127,8 +172,8 @@ def decorator(func: Callable) -> Callable:
" is ignored because `ignore_output` is True. Specify only one of"
" `ignore_output` or `output_formatter`."
)
result = (
async_observe_base(
if is_async(func):
return async_observe_base(
name=name,
ignore_input=ignore_input,
ignore_output=ignore_output,
Expand All @@ -139,8 +184,8 @@ def decorator(func: Callable) -> Callable:
association_properties=association_properties,
preserve_global_context=preserve_global_context,
)(func)
if is_async(func)
else observe_base(
else:
return observe_base(
name=name,
ignore_input=ignore_input,
ignore_output=ignore_output,
Expand All @@ -151,7 +196,5 @@ def decorator(func: Callable) -> Callable:
association_properties=association_properties,
preserve_global_context=preserve_global_context,
)(func)
)
return result

return cast(Callable, decorator)
return decorator