Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions langfuse/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@


from langfuse import Langfuse
from langfuse.client import InitialGeneration, CreateTrace
from langfuse.client import InitialGeneration, CreateTrace, StatefulGenerationClient

from distutils.version import StrictVersion
import openai
from wrapt import wrap_function_wrapper

from langfuse.model import UpdateGeneration


class OpenAiDefinition:
module: str
Expand Down Expand Up @@ -128,7 +130,7 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, langfuse: Langfus
return InitialGeneration(name=name, metadata=metadata, trace_id=trace_id, start_time=start_time, prompt=prompt, modelParameters=modelParameters, model=model)


def _get_lagnfuse_data_from_streaming_response(resource: OpenAiDefinition, response, generation: InitialGeneration, langfuse: Langfuse):
def _get_lagnfuse_data_from_streaming_response(resource: OpenAiDefinition, response, generation: StatefulGenerationClient, langfuse: Langfuse):
final_response = [] if resource.type == "chat" else ""
model = None
completion_start_time = None
Expand Down Expand Up @@ -180,12 +182,10 @@ def get_response_for_chat():
return final_response[-1]["tool_calls"]
return None

new_generation = generation.copy(
update={"end_time": datetime.now(), "completion": get_response_for_chat() if resource.type == "chat" else final_response, "completion_start_time": completion_start_time}
)
update = UpdateGeneration(end_time=datetime.now(), completion=get_response_for_chat() if resource.type == "chat" else final_response, completion_start_time=completion_start_time)
if model is not None:
new_generation = new_generation.copy(update={"model": model})
langfuse.generation(new_generation)
update = update.copy(update={"model": model})
generation.update(update)


def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, response):
Expand Down Expand Up @@ -225,21 +225,20 @@ def _wrap(open_ai_resource: OpenAiDefinition, langfuse: Langfuse, initialize, wr
arg_extractor = OpenAiArgsExtractor(*args, **kwargs)

generation = _get_langfuse_data_from_kwargs(open_ai_resource, new_langfuse, start_time, arg_extractor.get_langfuse_args())
updated_generation = generation
generation = new_langfuse.generation(generation)
try:
openai_response = wrapped(**arg_extractor.get_openai_args())

if _is_streaming_response(openai_response):
return _get_lagnfuse_data_from_streaming_response(open_ai_resource, openai_response, updated_generation, new_langfuse)
return _get_lagnfuse_data_from_streaming_response(open_ai_resource, openai_response, generation, new_langfuse)

else:
model, completion, usage = _get_langfuse_data_from_default_response(open_ai_resource, openai_response.__dict__ if _is_openai_v1() else openai_response)
updated_generation = generation.copy(update={"model": model, "completion": completion, "end_time": datetime.now(), "usage": usage})
new_langfuse.generation(updated_generation)
generation.update(UpdateGeneration(model=model, completion=completion, end_time=datetime.now(), usage=usage))
return openai_response
except Exception as ex:
model = kwargs.get("model", None)
new_langfuse.generation(updated_generation.copy(update={"end_time": datetime.now(), "status_message": str(ex), "level": "ERROR", "model": model}))
generation.update(UpdateGeneration(endTime=datetime.now(), statusMessage=str(ex), level="ERROR", model=model))
raise ex


Expand Down