diff --git a/langfuse/openai.py b/langfuse/openai.py index a80414b3d..ac16a8397 100644 --- a/langfuse/openai.py +++ b/langfuse/openai.py @@ -5,12 +5,14 @@ from langfuse import Langfuse -from langfuse.client import InitialGeneration, CreateTrace +from langfuse.client import InitialGeneration, CreateTrace, StatefulGenerationClient from distutils.version import StrictVersion import openai from wrapt import wrap_function_wrapper +from langfuse.model import UpdateGeneration + class OpenAiDefinition: module: str @@ -128,7 +130,7 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, langfuse: Langfus return InitialGeneration(name=name, metadata=metadata, trace_id=trace_id, start_time=start_time, prompt=prompt, modelParameters=modelParameters, model=model) -def _get_lagnfuse_data_from_streaming_response(resource: OpenAiDefinition, response, generation: InitialGeneration, langfuse: Langfuse): +def _get_lagnfuse_data_from_streaming_response(resource: OpenAiDefinition, response, generation: StatefulGenerationClient, langfuse: Langfuse): final_response = [] if resource.type == "chat" else "" model = None completion_start_time = None @@ -180,12 +182,10 @@ def get_response_for_chat(): return final_response[-1]["tool_calls"] return None - new_generation = generation.copy( - update={"end_time": datetime.now(), "completion": get_response_for_chat() if resource.type == "chat" else final_response, "completion_start_time": completion_start_time} - ) + update = UpdateGeneration(end_time=datetime.now(), completion=get_response_for_chat() if resource.type == "chat" else final_response, completion_start_time=completion_start_time) if model is not None: - new_generation = new_generation.copy(update={"model": model}) - langfuse.generation(new_generation) + update = update.copy(update={"model": model}) + generation.update(update) def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, response): @@ -225,21 +225,20 @@ def _wrap(open_ai_resource: OpenAiDefinition, langfuse: Langfuse, initialize, wr arg_extractor = OpenAiArgsExtractor(*args, **kwargs) generation = _get_langfuse_data_from_kwargs(open_ai_resource, new_langfuse, start_time, arg_extractor.get_langfuse_args()) - updated_generation = generation + generation = new_langfuse.generation(generation) try: openai_response = wrapped(**arg_extractor.get_openai_args()) if _is_streaming_response(openai_response): - return _get_lagnfuse_data_from_streaming_response(open_ai_resource, openai_response, updated_generation, new_langfuse) + return _get_lagnfuse_data_from_streaming_response(open_ai_resource, openai_response, generation, new_langfuse) else: model, completion, usage = _get_langfuse_data_from_default_response(open_ai_resource, openai_response.__dict__ if _is_openai_v1() else openai_response) - updated_generation = generation.copy(update={"model": model, "completion": completion, "end_time": datetime.now(), "usage": usage}) - new_langfuse.generation(updated_generation) + generation.update(UpdateGeneration(model=model, completion=completion, end_time=datetime.now(), usage=usage)) return openai_response except Exception as ex: model = kwargs.get("model", None) - new_langfuse.generation(updated_generation.copy(update={"end_time": datetime.now(), "status_message": str(ex), "level": "ERROR", "model": model})) + generation.update(UpdateGeneration(endTime=datetime.now(), statusMessage=str(ex), level="ERROR", model=model)) raise ex