54 changes: 17 additions & 37 deletions haystack/components/generators/chat/openai.py
@@ -402,15 +402,9 @@ def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreami
        chunk_delta: StreamingChunk

        for chunk in chat_completion:  # pylint: disable=not-an-iterable
-            # choices is an empty array for usage_chunk when include_usage is set to True
-            if chunk.usage is not None:
-                chunk_delta = self._convert_usage_chunk_to_streaming_chunk(chunk)
-
-            else:
-                assert len(chunk.choices) == 1, "Streaming responses should have only one choice."
-                chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+            assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
+            chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
            chunks.append(chunk_delta)

            callback(chunk_delta)
        return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]

@@ -422,15 +416,9 @@ async def _handle_async_stream_response(
        chunk_delta: StreamingChunk

        async for chunk in chat_completion:  # pylint: disable=not-an-iterable
-            # choices is an empty array for usage_chunk when include_usage is set to True
-            if chunk.usage is not None:
-                chunk_delta = self._convert_usage_chunk_to_streaming_chunk(chunk)
-
-            else:
-                assert len(chunk.choices) == 1, "Streaming responses should have only one choice."
-                chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+            assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
+            chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
            chunks.append(chunk_delta)

            await callback(chunk_delta)
        return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]

@@ -450,12 +438,12 @@ def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
            )

    def _convert_streaming_chunks_to_chat_message(
-        self, chunk: ChatCompletionChunk, chunks: List[StreamingChunk]
+        self, last_chunk: ChatCompletionChunk, chunks: List[StreamingChunk]
    ) -> ChatMessage:
        """
        Connects the streaming chunks into a single ChatMessage.

-        :param chunk: The last chunk returned by the OpenAI API.
+        :param last_chunk: The last chunk returned by the OpenAI API.
        :param chunks: The list of all `StreamingChunk` objects.

        :returns: The ChatMessage.
@@ -498,15 +486,18 @@ def _convert_streaming_chunks_to_chat_message(
                    _arguments=call_data["arguments"],
                )

-        # finish_reason is in the last chunk if usage is not included, and in the second last chunk if usage is included
-        finish_reason = (chunks[-2] if chunk.usage and len(chunks) >= 2 else chunks[-1]).meta.get("finish_reason")
+        # finish_reason can appear in different places so we look for the last one
+        finish_reasons = [
+            chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
+        ]
+        finish_reason = finish_reasons[-1] if finish_reasons else None

        meta = {
-            "model": chunk.model,
+            "model": last_chunk.model,
            "index": 0,
            "finish_reason": finish_reason,
            "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
-            "usage": chunk.usage or {},
+            "usage": dict(last_chunk.usage or {}),  # last chunk has the final usage data if available
sjrl (Contributor Author) commented:

I updated this to follow what we did in our normal chat completion processing, where we store the dict version of the usage; otherwise it's returned as an OpenAI Python type.
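
For illustration, a minimal sketch of what that conversion does, assuming the OpenAI SDK's Pydantic-based CompletionUsage model (the token counts below are made up):

# CompletionUsage is a Pydantic model, so dict() turns its top-level fields
# into a plain dict instead of an OpenAI Python type.
from openai.types.completion_usage import CompletionUsage

usage = CompletionUsage(completion_tokens=8, prompt_tokens=13, total_tokens=21)
meta_usage = dict(usage or {})  # mirrors the dict(last_chunk.usage or {}) expression above
assert isinstance(meta_usage, dict)
assert meta_usage["total_tokens"] == 21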

        }

        return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
@@ -558,6 +549,10 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio
        :returns:
            The StreamingChunk.
        """
+        # if there are no choices, return an empty chunk
+        if len(chunk.choices) == 0:
+            return StreamingChunk(content="", meta={"model": chunk.model, "received_at": datetime.now().isoformat()})
+
        # we stream the content of the chunk if it's not a tool or function call
        choice: ChunkChoice = chunk.choices[0]
        content = choice.delta.content or ""
@@ -574,18 +569,3 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio
            }
        )
        return chunk_message
-
-    def _convert_usage_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk:
-        """
-        Converts the usage chunk received from the OpenAI API when `include_usage` is set to `True` to a StreamingChunk.
-
-        :param chunk: The usage chunk returned by the OpenAI API.
Comment on lines -578 to -582
sjrl (Contributor Author) commented:

I removed this function since we didn't actually use the processed usage data from it when constructing the final ChatMessage. See the updated _convert_streaming_chunks_to_chat_message, where we now use the native chunk from OpenAI to provide the usage data.

Amnah199 (Contributor) commented on Mar 17, 2025:

Right, I'll run it locally for the original use case behind this PR to verify that this works.

sjrl (Contributor Author) commented on Mar 17, 2025:

I've also added a unit test and an integration test covering usage stats when streaming with include_usage: True, so I hope that covers it, but let me know if you have any issues.
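
For reference, a minimal sketch of the streaming setup that the new integration test exercises; the prompt and the print callback are illustrative, and OPENAI_API_KEY must be set:

# Request usage stats while streaming by passing stream_options
# through generation_kwargs.
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

generator = OpenAIChatGenerator(
    streaming_callback=print,  # any callable that accepts a StreamingChunk
    generation_kwargs={"stream_options": {"include_usage": True}},
)
reply = generator.run([ChatMessage.from_user("What's the capital of France?")])["replies"][0]
print(reply.meta["usage"])  # plain dict with prompt/completion/total token counts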


-
-        :returns:
-            The StreamingChunk.
-        """
-        chunk_message = StreamingChunk(content="")
-        chunk_message.meta.update(
-            {"model": chunk.model, "usage": chunk.usage, "received_at": datetime.now().isoformat()}
-        )
-        return chunk_message
59 changes: 56 additions & 3 deletions test/components/generators/chat/test_openai.py
@@ -12,6 +12,7 @@
from openai import OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import Choice
+from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.chat import chat_completion_chunk

@@ -397,7 +398,15 @@ def test_run_with_tools(self, tools):
                )
            ],
            created=int(datetime.now().timestamp()),
-            usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+            usage=CompletionUsage(
+                completion_tokens=40,
+                prompt_tokens=57,
+                total_tokens=97,
+                completion_tokens_details=CompletionTokensDetails(
+                    accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
+                ),
+                prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
+            ),
        )

        mock_chat_completion_create.return_value = completion
@@ -423,6 +432,7 @@ def test_run_with_tools(self, tools):
        assert tool_call.tool_name == "weather"
        assert tool_call.arguments == {"city": "Paris"}
        assert message.meta["finish_reason"] == "tool_calls"
+        assert message.meta["usage"]["completion_tokens"] == 40

    def test_run_with_tools_streaming(self, mock_chat_completion_chunk_with_tools, tools):
        streaming_callback_called = False
@@ -481,7 +491,15 @@ def test_invalid_tool_call_json(self, tools, caplog):
                )
            ],
            created=1234567890,
-            usage={"prompt_tokens": 50, "completion_tokens": 30, "total_tokens": 80},
+            usage=CompletionUsage(
+                completion_tokens=47,
+                prompt_tokens=540,
+                total_tokens=587,
+                completion_tokens_details=CompletionTokensDetails(
+                    accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
+                ),
+                prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
+            ),
        )

        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools)
@@ -491,6 +509,8 @@ def test_invalid_tool_call_json(self, tools, caplog):
        message = response["replies"][0]
        assert len(message.tool_calls) == 0
        assert "OpenAI returned a malformed JSON string for tool call arguments" in caplog.text
+        assert message.meta["finish_reason"] == "tool_calls"
+        assert message.meta["usage"]["completion_tokens"] == 47

    def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
@@ -782,6 +802,31 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
        assert result.meta["index"] == 0
        assert result.meta["completion_start_time"] == "2025-02-19T16:02:55.910076"

+    def test_convert_usage_chunk_to_streaming_chunk(self):
+        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
+        chunk = ChatCompletionChunk(
+            id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw",
+            choices=[],
+            created=1742207200,
+            model="gpt-4o-mini-2024-07-18",
+            object="chat.completion.chunk",
+            service_tier="default",
+            system_fingerprint="fp_06737a9306",
+            usage=CompletionUsage(
+                completion_tokens=8,
+                prompt_tokens=13,
+                total_tokens=21,
+                completion_tokens_details=CompletionTokensDetails(
+                    accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
+                ),
+                prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
+            ),
+        )
+        result = component._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+        assert result.content == ""
+        assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
+        assert result.meta["received_at"] is not None
+
    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
@@ -796,6 +841,7 @@ def test_live_run(self):
        assert "Paris" in message.text
        assert "gpt-4o" in message.meta["model"]
        assert message.meta["finish_reason"] == "stop"
+        assert message.meta["usage"]["prompt_tokens"] > 0

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
@@ -823,7 +869,9 @@ def __call__(self, chunk: StreamingChunk) -> None:
                self.responses += chunk.content if chunk.content else ""

        callback = Callback()
-        component = OpenAIChatGenerator(streaming_callback=callback)
+        component = OpenAIChatGenerator(
+            streaming_callback=callback, generation_kwargs={"stream_options": {"include_usage": True}}
+        )
        results = component.run([ChatMessage.from_user("What's the capital of France?")])

        assert len(results["replies"]) == 1
@@ -840,6 +888,11 @@ def __call__(self, chunk: StreamingChunk) -> None:
        assert "completion_start_time" in message.meta
        assert datetime.fromisoformat(message.meta["completion_start_time"]) < datetime.now()

+        assert isinstance(message.meta["usage"], dict)
+        assert message.meta["usage"]["prompt_tokens"] > 0
+        assert message.meta["usage"]["completion_tokens"] > 0
+        assert message.meta["usage"]["total_tokens"] > 0

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
21 changes: 19 additions & 2 deletions test/components/generators/chat/test_openai_async.py
@@ -10,6 +10,7 @@

from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice
+from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.chat import chat_completion_chunk

@@ -207,7 +208,15 @@ async def test_run_with_tools_async(self, tools):
                )
            ],
            created=int(datetime.now().timestamp()),
-            usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+            usage=CompletionUsage(
+                completion_tokens=40,
+                prompt_tokens=57,
+                total_tokens=97,
+                completion_tokens_details=CompletionTokensDetails(
+                    accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
+                ),
+                prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
+            ),
        )

        mock_chat_completion_create.return_value = completion
@@ -233,6 +242,7 @@ async def test_run_with_tools_async(self, tools):
        assert tool_call.tool_name == "weather"
        assert tool_call.arguments == {"city": "Paris"}
        assert message.meta["finish_reason"] == "tool_calls"
+        assert message.meta["usage"]["completion_tokens"] == 40

    @pytest.mark.asyncio
    async def test_run_with_tools_streaming_async(self, mock_chat_completion_chunk_with_tools, tools):
@@ -310,7 +320,9 @@ async def callback(chunk: StreamingChunk):
            counter += 1
            responses += chunk.content if chunk.content else ""

-        component = OpenAIChatGenerator(streaming_callback=callback)
+        component = OpenAIChatGenerator(
+            streaming_callback=callback, generation_kwargs={"stream_options": {"include_usage": True}}
+        )
        results = await component.run_async([ChatMessage.from_user("What's the capital of France?")])

        assert len(results["replies"]) == 1
@@ -327,6 +339,11 @@ async def callback(chunk: StreamingChunk):
        assert "completion_start_time" in message.meta
        assert datetime.fromisoformat(message.meta["completion_start_time"]) < datetime.now()

+        assert isinstance(message.meta["usage"], dict)
+        assert message.meta["usage"]["prompt_tokens"] > 0
+        assert message.meta["usage"]["completion_tokens"] > 0
+        assert message.meta["usage"]["total_tokens"] > 0

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",