Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions src/openai/lib/streaming/_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import httpx

from ..._utils import is_dict, is_list, consume_sync_iterator, consume_async_iterator
from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._compat import model_dump
from ..._models import construct_type
from ..._streaming import Stream, AsyncStream
Expand Down Expand Up @@ -994,25 +994,31 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) ->
#
# the same applies to `type` properties as they're used for
# discriminated unions
if key == "index" or key == "type":
if key in {"index", "type"}:
acc[key] = delta_value
continue

if isinstance(acc_value, str) and isinstance(delta_value, str):
acc_value += delta_value
elif isinstance(acc_value, (int, float)) and isinstance(delta_value, (int, float)):
acc_value += delta_value
elif is_dict(acc_value) and is_dict(delta_value):
elif isinstance(acc_value, dict) and isinstance(delta_value, dict):
acc_value = accumulate_delta(acc_value, delta_value)
elif is_list(acc_value) and is_list(delta_value):
elif isinstance(acc_value, list) and isinstance(delta_value, list):
# for lists of non-dictionary items we'll only ever get new entries
# in the array, existing entries will never be changed
if all(isinstance(x, (str, int, float)) for x in acc_value):
# Fast check for homogeneous types
acc_value_is_strintfloat = True
for x in acc_value:
if not isinstance(x, (str, int, float)):
acc_value_is_strintfloat = False
break
if acc_value_is_strintfloat:
acc_value.extend(delta_value)
continue

for delta_entry in delta_value:
if not is_dict(delta_entry):
if not isinstance(delta_entry, dict):
raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")

try:
Expand All @@ -1028,7 +1034,7 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) ->
except IndexError:
acc_value.insert(index, delta_entry)
else:
if not is_dict(acc_entry):
if not isinstance(acc_entry, dict):
raise TypeError("not handled yet")

acc_value[index] = accumulate_delta(acc_entry, delta_entry)
Expand Down