diff --git a/langfuse/openai.py b/langfuse/openai.py index 9601715e5..193f81aa0 100644 --- a/langfuse/openai.py +++ b/langfuse/openai.py @@ -89,6 +89,7 @@ def __init__( trace_id=None, session_id=None, user_id=None, + tags=None, parent_observation_id=None, **kwargs, ): @@ -98,6 +99,7 @@ def __init__( self.args["trace_id"] = trace_id self.args["session_id"] = session_id self.args["user_id"] = user_id + self.args["tags"] = tags self.args["parent_observation_id"] = parent_observation_id self.kwargs = kwargs @@ -187,6 +189,12 @@ def _get_langfuse_data_from_kwargs( if user_id is not None and not isinstance(user_id, str): raise TypeError("user_id must be a string") + tags = kwargs.get("tags", None) + if tags is not None and ( + not isinstance(tags, list) or not all(isinstance(tag, str) for tag in tags) + ): + raise TypeError("tags must be a list of strings") + parent_observation_id = kwargs.get("parent_observation_id", None) if parent_observation_id is not None and not isinstance(parent_observation_id, str): raise TypeError("parent_observation_id must be a string") @@ -194,10 +202,11 @@ def _get_langfuse_data_from_kwargs( raise ValueError("parent_observation_id requires trace_id to be set") if trace_id: - langfuse.trace(id=trace_id, session_id=session_id, user_id=user_id) - elif session_id: - # If a session_id is provided but no trace_id, we should create a trace using the SDK and then use its trace_id - trace_id = langfuse.trace(session_id=session_id, user_id=user_id).id + langfuse.trace(id=trace_id, session_id=session_id, user_id=user_id, tags=tags) + else: + trace_id = langfuse.trace( + session_id=session_id, user_id=user_id, tags=tags, name=name + ).id metadata = kwargs.get("metadata", {}) diff --git a/tests/test_openai.py b/tests/test_openai.py index 8f07af3f8..e349b99db 100644 --- a/tests/test_openai.py +++ b/tests/test_openai.py @@ -255,9 +255,11 @@ def test_openai_chat_completion_fail(): openai.api_key = os.environ["OPENAI_API_KEY"] -def test_openai_chat_completion_with_user_id(): +def test_openai_chat_completion_with_additional_params(): api = get_api() user_id = create_uuid() + session_id = create_uuid() + tags = ["tag1", "tag2"] trace_id = create_uuid() completion = chat_func( name="user-creation", @@ -267,14 +269,18 @@ def test_openai_chat_completion_with_user_id(): metadata={"someKey": "someResponse"}, user_id=user_id, trace_id=trace_id, + session_id=session_id, + tags=tags, ) openai.flush_langfuse() assert len(completion.choices) != 0 - traces = api.trace.get(trace_id) + trace = api.trace.get(trace_id) - assert traces.user_id == user_id + assert trace.user_id == user_id + assert trace.session_id == session_id + assert trace.tags == tags def test_openai_chat_completion_without_extra_param():