-
Notifications
You must be signed in to change notification settings - Fork 4.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix for concurrent OpenAIAgent streaming issue + notebook that was ab…
…le to reproduce the issue (#7164)
- Loading branch information
1 parent
df3fba6
commit cad264f
Showing
3 changed files
with
232 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Test for issues with concurrent chats" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# setup code\n", | ||
"import dotenv\n", | ||
"\n", | ||
"dotenv.load_dotenv(\"../../.env\")\n", | ||
"\n", | ||
"import nest_asyncio\n", | ||
"\n", | ||
"nest_asyncio.apply()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Test Single chat at a time" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"final response_str: Hello! How can I assist you today?\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from llama_index.llms import OpenAI, ChatMessage, MessageRole, ChatResponseAsyncGen\n", | ||
"\n", | ||
"# Necessary to use the latest OpenAI models that support function calling API\n", | ||
"async def run_chat():\n", | ||
" llm = OpenAI(model=\"gpt-3.5-turbo-0613\")\n", | ||
"\n", | ||
" streaming_chat_response: ChatResponseAsyncGen = await llm.astream_chat(\n", | ||
" [ChatMessage(content=\"Hi\", role=MessageRole.USER)]\n", | ||
" )\n", | ||
" response_str = \"\"\n", | ||
" async for chat_response in streaming_chat_response:\n", | ||
" response_str = chat_response.message.content\n", | ||
"\n", | ||
" print(\"final response_str: \" + response_str)\n", | ||
"\n", | ||
"\n", | ||
"await run_chat()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Test basic concurrency\n", | ||
"from typing import Callable\n", | ||
"from concurrent.futures import ThreadPoolExecutor\n", | ||
"import asyncio\n", | ||
"\n", | ||
"\n", | ||
"def run_in_event_loop(func: Callable):\n", | ||
" loop = asyncio.new_event_loop()\n", | ||
" asyncio.set_event_loop(loop)\n", | ||
" loop.run_until_complete(func())\n", | ||
" loop.close()\n", | ||
"\n", | ||
"\n", | ||
"async def run_task_n_times(task: Callable, n: int):\n", | ||
" def thread_task():\n", | ||
" run_in_event_loop(task)\n", | ||
"\n", | ||
" # Run two calls to run_chat concurrently and wait for them to finish\n", | ||
" with ThreadPoolExecutor(max_workers=5) as executor:\n", | ||
" futures = [executor.submit(thread_task) for _ in range(n)]\n", | ||
" for future in futures:\n", | ||
" future.result()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"final response_str: Hello! How can I assist you today?\n", | ||
"final response_str: Hello! How can I assist you today?\n", | ||
"final response_str: Hello! How can I assist you today?\n", | ||
"final response_str: Hello! How can I assist you today?\n", | ||
"final response_str: Hello! How can I assist you today?\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"await run_task_n_times(run_chat, 5)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Testing with `OpenAIAgent` based chat engine\n", | ||
"\n", | ||
"The above output looks fine for concurrent cases. Indicates there is no issue with concurrent chats coming from the LLM layer.\n", | ||
"Now will try from the other end of the stack with a full `OpenAIAgent` chat_engine." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from llama_index.agent import OpenAIAgent\n", | ||
"from llama_index.llms import OpenAI\n", | ||
"from llama_index.agent.openai_agent import StreamingAgentChatResponse\n", | ||
"\n", | ||
"\n", | ||
"async def run_chat():\n", | ||
" # initialize llm\n", | ||
" llm = OpenAI(model=\"gpt-3.5-turbo-0613\")\n", | ||
"\n", | ||
" # initialize openai agent\n", | ||
" agent = OpenAIAgent.from_tools([], llm=llm, verbose=True)\n", | ||
" streaming_chat_response: StreamingAgentChatResponse = await agent.astream_chat(\"Hi\")\n", | ||
" response_str = \"\"\n", | ||
" for text in streaming_chat_response.response_gen:\n", | ||
" response_str += text\n", | ||
" print(\"final response_str: \" + response_str)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Sanity check that the `OpenAIAgent` chat engine works on its own (no concurrent chats)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"final response_str: Hello! How can I assist you today?\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"await run_chat()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Ok now check for concurrency" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"final response_str: Hello! How can I assist you today?\n", | ||
"final response_str: Hello! How can I assist you today?\n", | ||
"final response_str: Hello! How can I assist you today?\n", | ||
"final response_str: Hello! How can I assist you today?\n", | ||
"final response_str: Hello! How can I assist you today?\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"await run_task_n_times(run_chat, 5)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.4" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters