Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 121 additions & 30 deletions gpt_oss/responses_api/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def generate_response(
dict[str, list[CodeInterpreterOutputLogs | CodeInterpreterOutputImage]]
] = None,
reasoning_ids: Optional[list[str]] = None,
message_ids: Optional[list[str]] = None,
treat_functions_python_as_builtin: bool = False,
) -> ResponseObject:
output = []
Expand All @@ -157,6 +158,7 @@ def generate_response(
browser_tool_index = 0
python_tool_index = 0
reasoning_ids_iter = iter(reasoning_ids or [])
message_ids_iter = iter(message_ids or [])

for entry in entries:
entry_dict = entry.to_dict()
Expand Down Expand Up @@ -296,15 +298,22 @@ def generate_response(
)
)

message_id = next(message_ids_iter, None)
output.append(
Item(
id=message_id,
type="message",
role="assistant",
content=content,
status="completed",
)
)
elif entry_dict["channel"] == "analysis":
if entry_dict.get("recipient"):
continue
author_dict = entry_dict.get("author") or {}
if author_dict.get("role") and author_dict.get("role") != "assistant":
continue
summary = []
content = [
ReasoningTextContentItem(
Expand Down Expand Up @@ -374,6 +383,7 @@ def generate_response(
)

class StreamResponsesEvents:
BROWSER_RESERVED_FUNCTIONS = {"browser.search", "browser.open", "browser.find"}
initial_tokens: list[int]
tokens: list[int]
output_tokens: list[int]
Expand Down Expand Up @@ -429,7 +439,48 @@ def __init__(
] = {}
self.reasoning_item_ids: list[str] = []
self.current_reasoning_item_id: Optional[str] = None
self.message_item_ids: list[str] = []
self.current_message_item_id: Optional[str] = None
self.functions_python_as_builtin = functions_python_as_builtin
self.user_defined_function_names = {
name
for tool in (request_body.tools or [])
for name in [getattr(tool, "name", None)]
if getattr(tool, "type", None) == "function" and name
}

def _resolve_browser_recipient(
self, recipient: Optional[str]
) -> tuple[Optional[str], bool]:
if not self.use_browser_tool or not recipient:
return (None, False)

if recipient.startswith("browser."):
return (recipient, False)

if recipient.startswith("functions."):
potential = recipient[len("functions.") :]
if (
potential in self.BROWSER_RESERVED_FUNCTIONS
and potential not in self.user_defined_function_names
):
return (potential, True)

return (None, False)

def _ensure_message_item_id(self) -> str:
if self.current_message_item_id is None:
message_id = f"item_{uuid.uuid4().hex}"
self.current_message_item_id = message_id
self.message_item_ids.append(message_id)
return self.current_message_item_id

def _ensure_reasoning_item_id(self) -> str:
if self.current_reasoning_item_id is None:
reasoning_id = f"rs_{uuid.uuid4().hex}"
self.current_reasoning_item_id = reasoning_id
self.reasoning_item_ids.append(reasoning_id)
return self.current_reasoning_item_id

def _send_event(self, event: ResponseEvent):
event.sequence_number = self.sequence_number
Expand All @@ -455,6 +506,7 @@ async def run(self):
python_call_ids=self.python_call_ids,
python_call_outputs=getattr(self, "python_call_outputs", None),
reasoning_ids=self.reasoning_item_ids,
message_ids=self.message_item_ids,
treat_functions_python_as_builtin=self.functions_python_as_builtin,
)
initial_response.status = "in_progress"
Expand Down Expand Up @@ -508,8 +560,11 @@ async def run(self):
previous_item = self.parser.messages[-1]
if previous_item.recipient is not None:
recipient = previous_item.recipient
browser_recipient, _ = self._resolve_browser_recipient(
recipient
)
if (
not recipient.startswith("browser.")
browser_recipient is None
and not (
recipient == "python"
or (
Expand Down Expand Up @@ -542,28 +597,34 @@ async def run(self):
),
)
)
if previous_item.channel == "analysis":
reasoning_id = self.current_reasoning_item_id
if reasoning_id is None:
reasoning_id = f"rs_{uuid.uuid4().hex}"
self.reasoning_item_ids.append(reasoning_id)
self.current_reasoning_item_id = reasoning_id
if (
previous_item.channel == "analysis"
and previous_item.recipient is None
):
reasoning_id = (
self.current_reasoning_item_id
if self.current_reasoning_item_id is not None
else self._ensure_reasoning_item_id()
)
reasoning_text = previous_item.content[0].text
yield self._send_event(
ResponseReasoningTextDone(
type="response.reasoning_text.done",
output_index=current_output_index,
content_index=current_content_index,
text=previous_item.content[0].text,
item_id=reasoning_id,
text=reasoning_text,
)
)
yield self._send_event(
ResponseContentPartDone(
type="response.content_part.done",
output_index=current_output_index,
content_index=current_content_index,
item_id=reasoning_id,
part=ReasoningTextContentItem(
type="reasoning_text",
text=previous_item.content[0].text,
text=reasoning_text,
),
)
)
Expand All @@ -578,7 +639,7 @@ async def run(self):
content=[
ReasoningTextContentItem(
type="reasoning_text",
text=previous_item.content[0].text,
text=reasoning_text,
)
],
),
Expand All @@ -605,11 +666,17 @@ async def run(self):
text=normalized_text,
annotations=annotations,
)
message_id = (
self.current_message_item_id
if self.current_message_item_id is not None
else self._ensure_message_item_id()
)
yield self._send_event(
ResponseOutputTextDone(
type="response.output_text.done",
output_index=current_output_index,
content_index=current_content_index,
item_id=message_id,
text=normalized_text,
)
)
Expand All @@ -618,6 +685,7 @@ async def run(self):
type="response.content_part.done",
output_index=current_output_index,
content_index=current_content_index,
item_id=message_id,
part=text_content,
)
)
Expand All @@ -626,6 +694,7 @@ async def run(self):
type="response.output_item.done",
output_index=current_output_index,
item=Item(
id=message_id,
type="message",
role="assistant",
content=[text_content],
Expand All @@ -634,6 +703,7 @@ async def run(self):
)
current_annotations = []
current_output_text_content = ""
self.current_message_item_id = None

if (
self.parser.last_content_delta
Expand All @@ -642,18 +712,25 @@ async def run(self):
):
if not sent_output_item_added:
sent_output_item_added = True
message_id = self._ensure_message_item_id()
yield self._send_event(
ResponseOutputItemAdded(
type="response.output_item.added",
output_index=current_output_index,
item=Item(type="message", role="assistant", content=[]),
item=Item(
id=message_id,
type="message",
role="assistant",
content=[],
),
)
)
yield self._send_event(
ResponseContentPartAdded(
type="response.content_part.added",
output_index=current_output_index,
content_index=current_content_index,
item_id=message_id,
part=TextContentItem(type="output_text", text=""),
)
)
Expand Down Expand Up @@ -685,11 +762,13 @@ async def run(self):
for a in new_annotations:
current_annotations.append(a)
citation = UrlCitation(**a)
message_id = self._ensure_message_item_id()
yield self._send_event(
ResponseOutputTextAnnotationAdded(
type="response.output_text.annotation.added",
output_index=current_output_index,
content_index=current_content_index,
item_id=message_id,
annotation_index=len(current_annotations),
annotation=citation,
)
Expand All @@ -699,11 +778,13 @@ async def run(self):
should_send_output_text_delta = False

if should_send_output_text_delta:
message_id = self._ensure_message_item_id()
yield self._send_event(
ResponseOutputTextDelta(
type="response.output_text.delta",
output_index=current_output_index,
content_index=current_content_index,
item_id=message_id,
delta=output_delta_buffer,
)
)
Expand All @@ -717,9 +798,7 @@ async def run(self):
):
if not sent_output_item_added:
sent_output_item_added = True
reasoning_id = f"rs_{uuid.uuid4().hex}"
self.current_reasoning_item_id = reasoning_id
self.reasoning_item_ids.append(reasoning_id)
reasoning_id = self._ensure_reasoning_item_id()
yield self._send_event(
ResponseOutputItemAdded(
type="response.output_item.added",
Expand All @@ -737,16 +816,19 @@ async def run(self):
type="response.content_part.added",
output_index=current_output_index,
content_index=current_content_index,
item_id=reasoning_id,
part=ReasoningTextContentItem(
type="reasoning_text", text=""
),
)
)
reasoning_id = self._ensure_reasoning_item_id()
yield self._send_event(
ResponseReasoningTextDelta(
type="response.reasoning_text.delta",
output_index=current_output_index,
content_index=current_content_index,
item_id=reasoning_id,
delta=self.parser.last_content_delta,
)
)
Expand All @@ -763,14 +845,20 @@ async def run(self):
if next_tok in encoding.stop_tokens_for_assistant_actions():
if len(self.parser.messages) > 0:
last_message = self.parser.messages[-1]
if (
self.use_browser_tool
and last_message.recipient is not None
and last_message.recipient.startswith("browser.")
):
function_name = last_message.recipient[len("browser.") :]
browser_recipient, is_browser_fallback = (
self._resolve_browser_recipient(last_message.recipient)
)
if browser_recipient is not None and browser_tool is not None:
message_for_browser = (
last_message
if not is_browser_fallback
else last_message.with_recipient(browser_recipient)
)
function_name = browser_recipient[len("browser.") :]
action = None
parsed_args = browser_tool.process_arguments(last_message)
parsed_args = browser_tool.process_arguments(
message_for_browser
)
if function_name == "search":
action = WebSearchActionSearch(
type="search",
Expand Down Expand Up @@ -810,25 +898,27 @@ async def run(self):
),
)
)
yield self._send_event(
ResponseWebSearchCallInProgress(
type="response.web_search_call.in_progress",
output_index=current_output_index,
id=web_search_call_id,
)
yield self._send_event(
ResponseWebSearchCallInProgress(
type="response.web_search_call.in_progress",
output_index=current_output_index,
item_id=web_search_call_id,
)
)

async def run_tool():
results = []
async for msg in browser_tool.process(last_message):
async for msg in browser_tool.process(
message_for_browser
):
results.append(msg)
return results

yield self._send_event(
ResponseWebSearchCallSearching(
type="response.web_search_call.searching",
output_index=current_output_index,
id=web_search_call_id,
item_id=web_search_call_id,
)
)
result = await run_tool()
Expand All @@ -852,7 +942,7 @@ async def run_tool():
ResponseWebSearchCallCompleted(
type="response.web_search_call.completed",
output_index=current_output_index,
id=web_search_call_id,
item_id=web_search_call_id,
)
)
yield self._send_event(
Expand Down Expand Up @@ -1030,6 +1120,7 @@ async def run_python_tool():
python_call_ids=self.python_call_ids,
python_call_outputs=self.python_call_outputs,
reasoning_ids=self.reasoning_item_ids,
message_ids=self.message_item_ids,
treat_functions_python_as_builtin=self.functions_python_as_builtin,
)
if self.store_callback and self.request_body.store:
Expand Down
1 change: 1 addition & 0 deletions gpt_oss/responses_api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ReasoningItem(BaseModel):


class Item(BaseModel):
id: Optional[str] = None
type: Optional[Literal["message"]] = "message"
role: Literal["user", "assistant", "system"]
content: Union[list[TextContentItem], str]
Expand Down
Loading