Skip to content

Commit

Permalink
fix(decorators): handle generators as decorated function inputs (#640)
Browse files Browse the repository at this point in the history
* fix: handle generators as decorated function inputs
* fix(decorators): serialization of function outputs
* fix(decorators): improve handling of class methods
  • Loading branch information
hassiebp committed May 2, 2024
1 parent fc205f1 commit 2a0868c
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 22 deletions.
71 changes: 53 additions & 18 deletions langfuse/decorators/langfuse_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
cast,
)


from langfuse.client import (
Langfuse,
StatefulSpanClient,
Expand All @@ -32,6 +31,7 @@
ModelUsage,
MapValue,
)
from langfuse.serializer import EventSerializer
from langfuse.types import ObservationParams, SpanLevel
from langfuse.utils import _get_timestamp
from langfuse.utils.langfuse_singleton import LangfuseSingleton
Expand Down Expand Up @@ -159,7 +159,7 @@ async def async_wrapper(*args, **kwargs):
func_name=func.__name__,
as_type=as_type,
capture_input=capture_input,
is_instance_method=self._is_instance_method(func),
is_method=self._is_method(func),
func_args=args,
func_kwargs=kwargs,
)
Expand Down Expand Up @@ -194,7 +194,7 @@ def sync_wrapper(*args, **kwargs):
func_name=func.__name__,
as_type=as_type,
capture_input=capture_input,
is_instance_method=self._is_instance_method(func),
is_method=self._is_method(func),
func_args=args,
func_kwargs=kwargs,
)
Expand All @@ -216,25 +216,28 @@ def sync_wrapper(*args, **kwargs):
return cast(F, sync_wrapper)

@staticmethod
def _is_instance_method(func: Callable) -> bool:
"""Check if a callable is likely an instance method based on its signature.
def _is_method(func: Callable) -> bool:
"""Check if a callable is likely a class or instance method based on its signature.
This method inspects the given callable's signature for the presence of a 'self' parameter, which is conventionally used for instance methods in Python classes. It returns True if 'self' is found among the parameters, suggesting the callable is an instance method.
This method inspects the given callable's signature for the presence of a 'cls' or 'self' parameter, which is conventionally used for class and instance methods in Python classes. It returns True if 'cls' or 'self' is found among the parameters, suggesting the callable is a method.
Note: This method relies on naming conventions and may not accurately identify instance methods if unconventional parameter names are used or if static or class methods incorrectly include a 'self' parameter. Additionally, during decorator execution, inspect.ismethod does not work as expected because the function has not yet been bound to an instance; it is still a function, not a method. This check attempts to infer method status based on signature, which can be useful in decorator contexts where traditional method identification techniques fail.
Note: This method relies on naming conventions and may not accurately identify instance methods if unconventional parameter names are used or if static or class methods incorrectly include a 'self' or 'cls' parameter. Additionally, during decorator execution, inspect.ismethod does not work as expected because the function has not yet been bound to an instance; it is still a function, not a method. This check attempts to infer method status based on signature, which can be useful in decorator contexts where traditional method identification techniques fail.
Returns:
bool: True if 'self' is in the callable's parameters, False otherwise.
bool: True if 'cls' or 'self' is in the callable's parameters, False otherwise.
"""
return "self" in inspect.signature(func).parameters
return (
"self" in inspect.signature(func).parameters
or "cls" in inspect.signature(func).parameters
)

def _prepare_call(
self,
*,
func_name: str,
as_type: Optional[Literal["generation"]],
capture_input: bool,
is_instance_method: bool = False,
is_method: bool = False,
func_args: Tuple = (),
func_kwargs: Dict = {},
) -> Optional[
Expand All @@ -251,14 +254,14 @@ def _prepare_call(
id = str(observation_id) if observation_id else None
start_time = _get_timestamp()

# Remove implicitly passed "self" argument for instance methods
if is_instance_method:
logged_args = func_args[1:]
else:
logged_args = func_args

input = (
{"args": logged_args, "kwargs": func_kwargs} if capture_input else None
self._get_input_from_func_args(
is_method=is_method,
func_args=func_args,
func_kwargs=func_kwargs,
)
if capture_input
else None
)

params = {
Expand Down Expand Up @@ -289,6 +292,38 @@ def _prepare_call(
except Exception as e:
self._log.error(f"Failed to prepare observation: {e}")

def _get_input_from_func_args(
self,
*,
is_method: bool = False,
func_args: Tuple = (),
func_kwargs: Dict = {},
) -> Any:
# Remove implicitly passed "self" or "cls" argument for instance or class methods
if is_method:
logged_args = func_args[1:]
else:
logged_args = func_args

# Remove generators from logged values
logged_args = [
f"<{type(arg).__name__}>"
if (inspect.isgenerator(arg) or inspect.isasyncgen(arg))
else arg
for arg in logged_args
]

logged_kwargs = {
k: (
f"<{type(v).__name__}>"
if inspect.isgenerator(v) or inspect.isasyncgen(v)
else v
)
for k, v in func_kwargs.items()
}

return {"args": logged_args, "kwargs": logged_kwargs}

def _finalize_call(
self,
observation: Optional[
Expand Down Expand Up @@ -340,7 +375,7 @@ def _handle_call_result(

end_time = observation_params["end_time"] or _get_timestamp()
output = observation_params["output"] or (
str(result) if result and capture_output else None
EventSerializer().default(result) if result and capture_output else None
)
observation_params.update(end_time=end_time, output=output)

Expand Down
78 changes: 74 additions & 4 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,16 @@ def main(*args, **kwargs):
assert trace_data.observations[0].output == "manually set output"


def test_decorated_instance_methods():
mock_name = "test_decorated_instance_methods"
def test_decorated_class_and_instance_methods():
mock_name = "test_decorated_class_and_instance_methods"
mock_trace_id = create_uuid()

class TestClass:
@classmethod
@observe()
def class_method(cls, *args, **kwargs):
return "class_method"

@observe(as_type="generation")
def level_3_function(self):
langfuse_context.update_current_observation(metadata=mock_metadata)
Expand All @@ -674,6 +679,8 @@ def level_3_function(self):

@observe()
def level_2_function(self):
TestClass.class_method()

self.level_3_function()
langfuse_context.update_current_observation(metadata=mock_metadata)

Expand All @@ -697,7 +704,7 @@ def level_1_function(self, *args, **kwargs):

trace_data = get_api().trace.get(mock_trace_id)
assert (
len(trace_data.observations) == 2
len(trace_data.observations) == 3
) # Top-most function is the trace, so it's not an observation

assert trace_data.input == {"args": list(mock_args), "kwargs": mock_kwargs}
Expand All @@ -716,7 +723,11 @@ def level_1_function(self, *args, **kwargs):
assert len(adjacencies) == 2 # Only trace and one observation have children

level_2_observation = adjacencies[mock_trace_id][0]
level_3_observation = adjacencies[level_2_observation.id][0]
level_3_observation = adjacencies[level_2_observation.id][1]
class_method_observation = adjacencies[level_2_observation.id][0]

assert class_method_observation.input == {"args": [], "kwargs": {}}
assert class_method_observation.output == "class_method"

assert level_2_observation.metadata == mock_metadata
assert level_3_observation.metadata == mock_deep_metadata
Expand Down Expand Up @@ -926,3 +937,62 @@ def main():
generation = trace_data.observations[0]
assert generation.type == "GENERATION"
assert generation.output == result


def test_generator_as_function_input():
    """A generator passed into an observed function is logged as a placeholder.

    The nested observed function still receives and consumes the real
    generator; only the recorded input shows "<generator>".
    """
    trace_id = create_uuid()
    expected_output = "Hello, World!"

    def generator_function():
        yield "Hello"
        yield ", "
        yield "World!"

    @observe()
    def nested(gen):
        # Consume the generator inside the observed call.
        return "".join(gen)

    @observe()
    def main(**kwargs):
        return nested(generator_function())

    result = main(langfuse_observation_id=trace_id)
    langfuse_context.flush()

    assert result == expected_output

    trace_data = get_api().trace.get(trace_id)
    assert trace_data.output == expected_output

    observation = trace_data.observations[0]
    assert "<generator>" in observation.input["args"]
    assert observation.output == "Hello, World!"

    started_at = observation.start_time
    ended_at = observation.end_time

    assert started_at is not None
    assert ended_at is not None
    assert started_at <= ended_at


def test_return_dict_for_output():
    """A dict returned by an observed function is recorded as the trace output."""
    trace_id = create_uuid()
    expected = {"key": "value"}

    @observe()
    def function():
        return expected

    returned = function(langfuse_observation_id=trace_id)
    langfuse_context.flush()

    assert returned == expected

    fetched_trace = get_api().trace.get(trace_id)
    assert fetched_trace.output == expected

0 comments on commit 2a0868c

Please sign in to comment.