Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 101 additions & 35 deletions langfuse/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import types
from typing import Optional

from packaging.version import Version


from langfuse import Langfuse
from langfuse.client import InitialGeneration, CreateTrace, StatefulGenerationClient

from distutils.version import StrictVersion
import openai
from wrapt import wrap_function_wrapper

Expand All @@ -19,12 +20,14 @@ class OpenAiDefinition:
object: str
method: str
type: str
sync: bool

def __init__(self, module: str, object: str, method: str, type: str):
def __init__(self, module: str, object: str, method: str, type: str, sync: bool):
self.module = module
self.object = object
self.method = method
self.type = type
self.sync = sync


OPENAI_METHODS_V0 = [
Expand All @@ -33,28 +36,34 @@ def __init__(self, module: str, object: str, method: str, type: str):
object="ChatCompletion",
method="create",
type="chat",
sync=True,
),
OpenAiDefinition(
module="openai",
object="Completion",
method="create",
type="completion",
sync=True,
),
]


OPENAI_METHODS_V1 = [
OpenAiDefinition(module="openai.resources.chat.completions", object="Completions", method="create", type="chat", sync=True),
OpenAiDefinition(module="openai.resources.completions", object="Completions", method="create", type="completion", sync=True),
OpenAiDefinition(
module="openai.resources.chat.completions",
object="Completions",
object="AsyncCompletions",
method="create",
type="chat",
sync=False,
),
OpenAiDefinition(
module="openai.resources.completions",
object="Completions",
object="AsyncCompletions",
method="create",
type="completion",
sync=False,
),
]

Expand All @@ -75,9 +84,9 @@ def get_openai_args(self):


def _langfuse_wrapper(func):
def _with_langfuse(open_ai_definitions, langfuse, initialize):
def _with_langfuse(open_ai_definitions, initialize):
def wrapper(wrapped, instance, args, kwargs):
return func(open_ai_definitions, langfuse, initialize, wrapped, instance, args, kwargs)
return func(open_ai_definitions, initialize, wrapped, args, kwargs)

return wrapper

Expand Down Expand Up @@ -130,12 +139,41 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, langfuse: Langfus
return InitialGeneration(name=name, metadata=metadata, trace_id=trace_id, start_time=start_time, prompt=prompt, modelParameters=modelParameters, model=model)


def _get_lagnfuse_data_from_streaming_response(resource: OpenAiDefinition, response, generation: StatefulGenerationClient, langfuse: Langfuse):
final_response = [] if resource.type == "chat" else ""
def _get_lagnfuse_data_from_sync_streaming_response(resource: OpenAiDefinition, response, generation: StatefulGenerationClient, langfuse: Langfuse):
    """Proxy a synchronous streaming OpenAI response chunk-by-chunk.

    Yields each chunk unchanged to the caller while buffering it; once the
    stream is exhausted, aggregates the buffered chunks and finalizes the
    Langfuse generation.
    """
    buffered_chunks = []
    for chunk in response:
        buffered_chunks.append(chunk)
        yield chunk
    # Stream fully consumed: derive the aggregate completion and close out
    # the generation on Langfuse.
    model, completion_start_time, completion = _extract_data(resource, buffered_chunks)
    _create_langfuse_update(completion, generation, completion_start_time, model=model)


async def _get_lagnfuse_data_from_async_streaming_response(resource: OpenAiDefinition, response, generation: StatefulGenerationClient, langfuse: Langfuse):
    """Proxy an asynchronous streaming OpenAI response chunk-by-chunk.

    Async twin of the sync variant: re-yields every chunk while buffering it,
    then, after the stream ends, aggregates the chunks and finalizes the
    Langfuse generation.
    """
    buffered_chunks = []
    async for chunk in response:
        buffered_chunks.append(chunk)
        yield chunk
    # Stream fully consumed: derive the aggregate completion and close out
    # the generation on Langfuse.
    model, completion_start_time, completion = _extract_data(resource, buffered_chunks)
    _create_langfuse_update(completion, generation, completion_start_time, model=model)


def _create_langfuse_update(completion, generation: StatefulGenerationClient, completion_start_time, model=None):
    """Finalize *generation* with the streamed completion data.

    The model is attached via a copy only when it is known — presumably so
    that an absent model does not overwrite an existing value with None
    (TODO confirm against UpdateGeneration semantics).
    """
    payload = UpdateGeneration(
        end_time=datetime.now(),
        completion=completion,
        completion_start_time=completion_start_time,
    )
    if model is not None:
        payload = payload.copy(update={"model": model})
    generation.update(payload)


def _extract_data(resource, responses):
completion = [] if resource.type == "chat" else ""
model = None
completion_start_time = None
for index, i in enumerate(response):
print(index)

for index, i in enumerate(responses):
if index == 0:
completion_start_time = datetime.now()

Expand All @@ -156,36 +194,31 @@ def _get_lagnfuse_data_from_streaming_response(resource: OpenAiDefinition, respo
delta = delta.__dict__

if delta.get("role", None) is not None:
final_response.append({"role": delta.get("role", None), "function_call": None, "tool_calls": None, "content": None})
completion.append({"role": delta.get("role", None), "function_call": None, "tool_calls": None, "content": None})

elif delta.get("content", None) is not None:
final_response[-1]["content"] = delta.get("content", None) if final_response[-1]["content"] is None else final_response[-1]["content"] + delta.get("content", None)
completion[-1]["content"] = delta.get("content", None) if completion[-1]["content"] is None else completion[-1]["content"] + delta.get("content", None)

elif delta.get("function_call", None) is not None:
final_response[-1]["function_call"] = (
delta.get("function_call", None) if final_response[-1]["function_call"] is None else final_response[-1]["function_call"] + delta.get("function_call", None)
completion[-1]["function_call"] = (
delta.get("function_call", None) if completion[-1]["function_call"] is None else completion[-1]["function_call"] + delta.get("function_call", None)
)
elif delta.get("tools_call", None) is not None:
final_response[-1]["tool_calls"] = delta.get("tools_call", None) if final_response[-1]["tool_calls"] is None else final_response[-1]["tool_calls"] + delta.get("tools_call", None)
completion[-1]["tool_calls"] = delta.get("tools_call", None) if completion[-1]["tool_calls"] is None else completion[-1]["tool_calls"] + delta.get("tools_call", None)
if resource.type == "completion":
final_response += choice.get("text", None)

yield i
completion += choice.get("text", None)

def get_response_for_chat():
if len(final_response) > 0:
if final_response[-1].get("content", None) is not None:
return final_response[-1]["content"]
elif final_response[-1].get("function_call", None) is not None:
return final_response[-1]["function_call"]
elif final_response[-1].get("tool_calls", None) is not None:
return final_response[-1]["tool_calls"]
if len(completion) > 0:
if completion[-1].get("content", None) is not None:
return completion[-1]["content"]
elif completion[-1].get("function_call", None) is not None:
return completion[-1]["function_call"]
elif completion[-1].get("tool_calls", None) is not None:
return completion[-1]["tool_calls"]
return None

update = UpdateGeneration(end_time=datetime.now(), completion=get_response_for_chat() if resource.type == "chat" else final_response, completion_start_time=completion_start_time)
if model is not None:
update = update.copy(update={"model": model})
generation.update(update)
return model, completion_start_time, get_response_for_chat() if resource.type == "chat" else completion


def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, response):
Expand All @@ -210,15 +243,15 @@ def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, respons


def _is_openai_v1():
    """Return True when the installed openai SDK is version 1.0.0 or newer.

    Uses packaging.version.Version for the comparison; the earlier
    distutils.version.StrictVersion is deprecated (distutils is removed in
    Python 3.12) and its import was dropped from this module, so the stale
    StrictVersion-based return line (diff residue) would raise NameError and
    is removed here.
    """
    return Version(openai.__version__) >= Version("1.0.0")


def _is_streaming_response(response):
    """Return True if *response* is a streamed result rather than a fully
    materialized completion.

    Covers plain generators (openai v0 streaming) and the v1 SDK's Stream
    and AsyncStream wrappers. The original span contained two return
    statements (diff residue); the first lacked the AsyncStream check and
    shadowed the corrected line, so async streams would never be detected.
    The duplicated _is_openai_v1() calls are also collapsed into a single
    guard with an isinstance tuple.
    """
    if isinstance(response, types.GeneratorType):
        return True
    return _is_openai_v1() and isinstance(response, (openai.Stream, openai.AsyncStream))


@_langfuse_wrapper
def _wrap(open_ai_resource: OpenAiDefinition, langfuse: Langfuse, initialize, wrapped, instance, args, kwargs):
def _wrap(open_ai_resource: OpenAiDefinition, initialize, wrapped, args, kwargs):
new_langfuse = initialize()

start_time = datetime.now()
Expand All @@ -230,7 +263,31 @@ def _wrap(open_ai_resource: OpenAiDefinition, langfuse: Langfuse, initialize, wr
openai_response = wrapped(**arg_extractor.get_openai_args())

if _is_streaming_response(openai_response):
return _get_lagnfuse_data_from_streaming_response(open_ai_resource, openai_response, generation, new_langfuse)
return _get_lagnfuse_data_from_sync_streaming_response(open_ai_resource, openai_response, generation, new_langfuse)

else:
model, completion, usage = _get_langfuse_data_from_default_response(open_ai_resource, openai_response.__dict__ if _is_openai_v1() else openai_response)
generation.update(UpdateGeneration(model=model, completion=completion, end_time=datetime.now(), usage=usage))
return openai_response
except Exception as ex:
model = kwargs.get("model", None)
generation.update(UpdateGeneration(endTime=datetime.now(), statusMessage=str(ex), level="ERROR", model=model))
raise ex


@_langfuse_wrapper
async def _wrap_async(open_ai_resource: OpenAiDefinition, initialize, wrapped, args, kwargs):
new_langfuse = initialize()
start_time = datetime.now()
arg_extractor = OpenAiArgsExtractor(*args, **kwargs)

generation = _get_langfuse_data_from_kwargs(open_ai_resource, new_langfuse, start_time, arg_extractor.get_langfuse_args())
generation = new_langfuse.generation(generation)
try:
openai_response = await wrapped(**arg_extractor.get_openai_args())

if _is_streaming_response(openai_response):
return _get_lagnfuse_data_from_async_streaming_response(open_ai_resource, openai_response, generation, new_langfuse)

else:
model, completion, usage = _get_langfuse_data_from_default_response(open_ai_resource, openai_response.__dict__ if _is_openai_v1() else openai_response)
Expand Down Expand Up @@ -271,15 +328,24 @@ def register_tracing(self):
wrap_function_wrapper(
resource.module,
f"{resource.object}.{resource.method}",
_wrap(resource, self._langfuse, self.initialize),
_wrap(resource, self.initialize) if resource.sync else _wrap_async(resource, self.initialize),
)

setattr(openai, "langfuse_public_key", None)
setattr(openai, "langfuse_secret_key", None)
setattr(openai, "langfuse_host", None)

setattr(openai, "flush_langfuse", self.flush)

setattr(openai.AsyncOpenAI, "langfuse_public_key", None)
setattr(openai.AsyncOpenAI, "langfuse_secret_key", None)
setattr(openai.AsyncOpenAI, "langfuse_host", None)
setattr(openai.AsyncOpenAI, "flush_langfuse", self.flush)

setattr(openai.OpenAI, "langfuse_public_key", None)
setattr(openai.OpenAI, "langfuse_secret_key", None)
setattr(openai.OpenAI, "langfuse_host", None)
setattr(openai.OpenAI, "flush_langfuse", self.flush)


modifier = OpenAILangfuse()
modifier.register_tracing()
74 changes: 74 additions & 0 deletions tests/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from langfuse.openai import _is_openai_v1, _is_streaming_response, openai

from tests.utils import create_uuid, get_api
from openai import AsyncOpenAI


chat_func = openai.chat.completions.create if _is_openai_v1() else openai.ChatCompletion.create
Expand Down Expand Up @@ -460,3 +461,76 @@ def test_fails_wrong_trace_id():
prompt="1 + 1 = ",
temperature=0,
)


@pytest.mark.asyncio
async def test_async_chat():
    # Integration test: exercises the AsyncOpenAI client wrapper end-to-end.
    # Requires live OpenAI and Langfuse credentials in the environment.
    api = get_api()
    client = AsyncOpenAI()
    # Unique name so the Langfuse query below matches only this run.
    generation_name = create_uuid()

    completion = await client.chat.completions.create(messages=[{"role": "user", "content": "1 + 1 = "}], model="gpt-3.5-turbo", name=generation_name)

    # Flush buffered Langfuse events before querying the observations API,
    # otherwise the generation may not be visible yet.
    client.flush_langfuse()
    print(completion)

    generation = api.observations.get_many(name=generation_name, type="GENERATION")

    # The traced generation must exist and mirror the request/response data.
    assert len(generation.data) != 0
    assert generation.data[0].name == generation_name
    assert len(completion.choices) != 0
    assert completion.choices[0].message.content == generation.data[0].output
    assert generation.data[0].input == [{"content": "1 + 1 = ", "role": "user"}]
    assert generation.data[0].type == "GENERATION"
    # NOTE(review): pinned to the 0613 snapshot the API resolved to at the
    # time — brittle if OpenAI updates the default gpt-3.5-turbo alias.
    assert generation.data[0].model == "gpt-3.5-turbo-0613"
    assert generation.data[0].start_time is not None
    assert generation.data[0].end_time is not None
    assert generation.data[0].start_time < generation.data[0].end_time
    # Default model parameters as recorded by the wrapper.
    assert generation.data[0].model_parameters == {
        "temperature": 1,
        "top_p": 1,
        "frequency_penalty": 0,
        "maxTokens": "inf",
        "presence_penalty": 0,
    }
    assert generation.data[0].prompt_tokens is not None
    assert generation.data[0].completion_tokens is not None
    assert generation.data[0].total_tokens is not None
    assert generation.data[0].output == "2"


@pytest.mark.asyncio
async def test_async_chat_stream():
    # Integration test: exercises the async *streaming* path of the wrapper.
    # Requires live OpenAI and Langfuse credentials in the environment.
    api = get_api()
    client = AsyncOpenAI()
    # Unique name so the Langfuse query below matches only this run.
    generation_name = create_uuid()

    completion = await client.chat.completions.create(messages=[{"role": "user", "content": "1 + 1 = "}], model="gpt-3.5-turbo", name=generation_name, stream=True)

    # The generation is only finalized once the stream is fully consumed,
    # so drain every chunk before flushing.
    async for c in completion:
        print(c)

    # Flush buffered Langfuse events before querying the observations API.
    client.flush_langfuse()
    print(completion)

    generation = api.observations.get_many(name=generation_name, type="GENERATION")

    # The traced generation must exist and mirror the request/response data.
    assert len(generation.data) != 0
    assert generation.data[0].name == generation_name
    assert generation.data[0].input == [{"content": "1 + 1 = ", "role": "user"}]
    assert generation.data[0].type == "GENERATION"
    # NOTE(review): pinned to the 0613 snapshot the API resolved to at the
    # time — brittle if OpenAI updates the default gpt-3.5-turbo alias.
    assert generation.data[0].model == "gpt-3.5-turbo-0613"
    assert generation.data[0].start_time is not None
    assert generation.data[0].end_time is not None
    assert generation.data[0].start_time < generation.data[0].end_time
    # Default model parameters as recorded by the wrapper.
    assert generation.data[0].model_parameters == {
        "temperature": 1,
        "top_p": 1,
        "frequency_penalty": 0,
        "maxTokens": "inf",
        "presence_penalty": 0,
    }
    # NOTE(review): token counts on a streamed response are presumably
    # inferred by Langfuse rather than reported by OpenAI — confirm.
    assert generation.data[0].prompt_tokens is not None
    assert generation.data[0].completion_tokens is not None
    assert generation.data[0].total_tokens is not None
    assert generation.data[0].output == "2"