-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(assistants): add support for streaming (#1233)
See the reference docs for more information: https://platform.openai.com/docs/api-reference/assistants-streaming We've also improved some of the names for the types in the assistants beta, non exhaustive list: - `CodeToolCall` -> `CodeInterpreterToolCall` - `MessageContentImageFile` -> `ImageFileContentBlock` - `MessageContentText` -> `TextContentBlock` - `ThreadMessage` -> `Message` - `ThreadMessageDeleted` -> `MessageDeleted`
- Loading branch information
1 parent
ec104bf
commit 17635dc
Showing
75 changed files
with
4,443 additions
and
485 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import openai | ||
|
||
# gets API Key from environment variable OPENAI_API_KEY | ||
client = openai.OpenAI() | ||
|
||
assistant = client.beta.assistants.create( | ||
name="Math Tutor", | ||
instructions="You are a personal math tutor. Write and run code to answer math questions.", | ||
tools=[{"type": "code_interpreter"}], | ||
model="gpt-4-1106-preview", | ||
) | ||
|
||
thread = client.beta.threads.create() | ||
|
||
message = client.beta.threads.messages.create( | ||
thread_id=thread.id, | ||
role="user", | ||
content="I need to solve the equation `3x + 11 = 14`. Can you help me?", | ||
) | ||
|
||
print("starting run stream") | ||
|
||
stream = client.beta.threads.runs.create( | ||
thread_id=thread.id, | ||
assistant_id=assistant.id, | ||
instructions="Please address the user as Jane Doe. The user has a premium account.", | ||
stream=True, | ||
) | ||
|
||
for event in stream: | ||
print(event.model_dump_json(indent=2, exclude_unset=True)) | ||
|
||
client.beta.assistants.delete(assistant.id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from __future__ import annotations | ||
|
||
from typing_extensions import override | ||
|
||
import openai | ||
from openai import AssistantEventHandler | ||
from openai.types.beta import AssistantStreamEvent | ||
from openai.types.beta.threads import Text, TextDelta | ||
from openai.types.beta.threads.runs import RunStep, RunStepDelta | ||
|
||
|
||
class EventHandler(AssistantEventHandler): | ||
@override | ||
def on_event(self, event: AssistantStreamEvent) -> None: | ||
if event.event == "thread.run.step.created": | ||
details = event.data.step_details | ||
if details.type == "tool_calls": | ||
print("Generating code to interpret:\n\n```py") | ||
elif event.event == "thread.message.created": | ||
print("\nResponse:\n") | ||
|
||
@override | ||
def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: | ||
print(delta.value, end="", flush=True) | ||
|
||
@override | ||
def on_run_step_done(self, run_step: RunStep) -> None: | ||
details = run_step.step_details | ||
if details.type == "tool_calls": | ||
for tool in details.tool_calls: | ||
if tool.type == "code_interpreter": | ||
print("\n```\nExecuting code...") | ||
|
||
@override | ||
def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None: | ||
details = delta.step_details | ||
if details is not None and details.type == "tool_calls": | ||
for tool in details.tool_calls or []: | ||
if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input: | ||
print(tool.code_interpreter.input, end="", flush=True) | ||
|
||
|
||
def main() -> None: | ||
client = openai.OpenAI() | ||
|
||
assistant = client.beta.assistants.create( | ||
name="Math Tutor", | ||
instructions="You are a personal math tutor. Write and run code to answer math questions.", | ||
tools=[{"type": "code_interpreter"}], | ||
model="gpt-4-1106-preview", | ||
) | ||
|
||
try: | ||
question = "I need to solve the equation `3x + 11 = 14`. Can you help me?" | ||
|
||
thread = client.beta.threads.create( | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": question, | ||
}, | ||
] | ||
) | ||
print(f"Question: {question}\n") | ||
|
||
with client.beta.threads.runs.create_and_stream( | ||
thread_id=thread.id, | ||
assistant_id=assistant.id, | ||
instructions="Please address the user as Jane Doe. The user has a premium account.", | ||
event_handler=EventHandler(), | ||
) as stream: | ||
stream.until_done() | ||
print() | ||
finally: | ||
client.beta.assistants.delete(assistant.id) | ||
|
||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.