Do not create new session if conversation_id is provided #163
First changed file, the `/query` endpoint module:

```diff
@@ -6,8 +6,10 @@
 import os
 from pathlib import Path
 from typing import Any
 
```

> **Contributor:** Nit: an empty line separating the built-in library from the 3rd-party libraries.

```diff
-from llama_stack_client.lib.agents.agent import Agent
+from cachetools import TTLCache  # type: ignore
+
+from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client import APIConnectionError
 from llama_stack_client import LlamaStackClient  # type: ignore
 from llama_stack_client.types import UserMessage  # type: ignore
```
```diff
@@ -32,6 +34,8 @@
 logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["query"])
 
+# Global agent registry to persist agents across requests
+_agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600)
+
 query_response: dict[int | str, dict[str, Any]] = {
     200: {
```
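
A note on the `_agent_cache` declaration above: `TTLCache` expires an entry `ttl` seconds after it is stored and evicts the least recently used entry once `maxsize` is exceeded, so a cached agent disappears an hour after insertion. A minimal sketch of that behavior, with toy values rather than the PR's actual 1000/3600:

```python
from time import sleep

from cachetools import TTLCache

# Same shape as the registry above, but tiny so eviction is visible
cache: TTLCache[str, str] = TTLCache(maxsize=2, ttl=1)

cache["conv-1"] = "agent-1"
cache["conv-2"] = "agent-2"
cache["conv-3"] = "agent-3"  # maxsize exceeded: least recently used entry is dropped

print("conv-1" in cache)    # False -- evicted by maxsize
print(cache.get("conv-3"))  # "agent-3"

sleep(1.5)
print(cache.get("conv-3"))  # None -- expired after ttl seconds
```
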
```diff
@@ -56,16 +60,33 @@ def is_transcripts_enabled() -> bool:
     return not configuration.user_data_collection_configuration.transcripts_disabled
 
 
-def retrieve_conversation_id(query_request: QueryRequest) -> str:
-    """Retrieve conversation ID based on existing ID or on newly generated one."""
-    conversation_id = query_request.conversation_id
-
-    # Generate a new conversation ID if not provided
-    if not conversation_id:
-        conversation_id = get_suid()
-        logger.info("Generated new conversation ID: %s", conversation_id)
-
-    return conversation_id
+def get_agent(
+    client: LlamaStackClient,
+    model_id: str,
+    system_prompt: str,
+    available_shields: list[str],
+    conversation_id: str | None,
+) -> tuple[Agent, str]:
+    """Get existing agent or create a new one with session persistence."""
+    if conversation_id is not None:
+        agent = _agent_cache.get(conversation_id)
+        if agent:
+            logger.debug("Reusing existing agent with key: %s", conversation_id)
+            return agent, conversation_id
+
+    logger.debug("Creating new agent")
+    # TODO(lucasagomes): move to ReActAgent
+    agent = Agent(
+        client,
+        model=model_id,
+        instructions=system_prompt,
```

> **Contributor** (on `instructions=system_prompt,`): This represents an interesting scenario I missed earlier. The […] Meaning I think we should probably remove […]

> **Contributor:** TBH, it seems strange that the system prompt is passed in the request. I think that we should keep this, and in case the […]. Unsure whether that is the right thing to do here; perhaps it's better to defer this and solve it in another PR (along with #123).

```diff
+        input_shields=available_shields if available_shields else [],
+        tools=[mcp.name for mcp in configuration.mcp_servers],
+        enable_session_persistence=True,
+    )
+    conversation_id = agent.create_session(get_suid())
+    _agent_cache[conversation_id] = agent
+    return agent, conversation_id
 
 
 @router.post("/query", responses=query_response)
```
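
The net effect for API clients: a first call can omit `conversation_id`, and sending the returned ID back reuses the cached agent and its session instead of creating new ones. A hedged sketch of that round trip, assuming a local deployment URL and assuming the endpoint echoes `conversation_id` in its JSON response (neither is shown in this diff):

```python
import requests

BASE_URL = "http://localhost:8080"  # assumed local deployment, not from the PR

# First request: no conversation_id, so the server creates an agent and session
first = requests.post(
    f"{BASE_URL}/query",
    json={"query": "What is Kubernetes?"},
    timeout=60,
).json()
conversation_id = first["conversation_id"]  # assumed response field

# Follow-up request: sending the ID back hits the agent cache, so the turn
# lands in the same session rather than a brand-new one
second = requests.post(
    f"{BASE_URL}/query",
    json={"query": "How do I scale a deployment?", "conversation_id": conversation_id},
    timeout=60,
).json()
```
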
```diff
@@ -83,8 +104,9 @@ def query_endpoint_handler(
     # try to get Llama Stack client
     client = get_llama_stack_client(llama_stack_config)
     model_id = select_model_id(client.models.list(), query_request)
-    conversation_id = retrieve_conversation_id(query_request)
-    response = retrieve_response(client, model_id, query_request, auth)
+    response, conversation_id = retrieve_response(
+        client, model_id, query_request, auth
+    )
 
     if not is_transcripts_enabled():
         logger.debug("Transcript collection is disabled in the configuration")
```
```diff
@@ -163,7 +185,7 @@ def retrieve_response(
     model_id: str,
     query_request: QueryRequest,
     token: str,
-) -> str:
+) -> tuple[str, str]:
     """Retrieve response from LLMs and agents."""
     available_shields = [shield.identifier for shield in client.shields.list()]
     if not available_shields:
```
```diff
@@ -184,21 +206,28 @@
     if query_request.attachments:
         validate_attachments_metadata(query_request.attachments)
 
     # Build mcp_headers config dynamically for all MCP servers
     # this will allow the agent to pass the user token to the MCP servers
+    agent, conversation_id = get_agent(
+        client,
+        model_id,
+        system_prompt,
+        available_shields,
+        query_request.conversation_id,
+    )
+
     mcp_headers = {}
     if token:
         for mcp_server in configuration.mcp_servers:
             mcp_headers[mcp_server.url] = {
                 "Authorization": f"Bearer {token}",
             }
-    # TODO(lucasagomes): move to ReActAgent
-    agent = Agent(
-        client,
-        model=model_id,
-        instructions=system_prompt,
-        input_shields=available_shields if available_shields else [],
-        tools=[mcp.name for mcp in configuration.mcp_servers],
-    )
+    vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
+    response = agent.create_turn(
+        messages=[UserMessage(role="user", content=query_request.query)],
```

> **Contributor** (on the `messages=` line): … and change this to be: […] This is what […]

```diff
+        session_id=conversation_id,
+        documents=query_request.get_documents(),
+        stream=False,
+        toolgroups=get_rag_toolgroups(vector_db_ids),
         extra_headers={
             "X-LlamaStack-Provider-Data": json.dumps(
                 {
@@ -207,17 +236,8 @@
             ),
         },
     )
-    session_id = agent.create_session("chat_session")
-    logger.debug("Session ID: %s", session_id)
-    vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
-    response = agent.create_turn(
-        messages=[UserMessage(role="user", content=query_request.query)],
-        session_id=session_id,
-        documents=query_request.get_documents(),
-        stream=False,
-        toolgroups=get_rag_toolgroups(vector_db_ids),
-    )
-    return str(response.output_message.content)  # type: ignore[union-attr]
+
+    return str(response.output_message.content), conversation_id  # type: ignore[union-attr]
 
 
 def validate_attachments_metadata(attachments: list[Attachment]) -> None:
```
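
For orientation, the `extra_headers` block above serializes the per-server MCP authorization headers into a single `X-LlamaStack-Provider-Data` header. A sketch of that construction; the server URLs are made up, and the inner `mcp_headers` key is an assumption, since the payload body is elided between the two hunks:

```python
import json

# Hypothetical inputs -- stand-ins for configuration.mcp_servers and the user token
token = "user-token"
mcp_server_urls = ["https://mcp-a.example.com", "https://mcp-b.example.com"]

# Mirrors the loop in retrieve_response: one Authorization header per MCP server
mcp_headers = {url: {"Authorization": f"Bearer {token}"} for url in mcp_server_urls}

extra_headers = {
    # inner key assumed; the actual payload shape is not visible in the diff
    "X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": mcp_headers}),
}
print(extra_headers["X-LlamaStack-Provider-Data"])
```
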
Second changed file, the streaming query endpoint module:

```diff
@@ -4,6 +4,8 @@
 import logging
 from typing import Any, AsyncIterator
 
+from cachetools import TTLCache  # type: ignore
+
 from llama_stack_client import APIConnectionError
 from llama_stack_client.lib.agents.agent import AsyncAgent  # type: ignore
 from llama_stack_client import AsyncLlamaStackClient  # type: ignore
```
```diff
@@ -19,12 +21,12 @@
 from utils.auth import auth_dependency
 from utils.endpoints import check_configuration_loaded
 from utils.common import retrieve_user_id
+from utils.suid import get_suid
 
 from app.endpoints.query import (
     get_rag_toolgroups,
     is_transcripts_enabled,
-    retrieve_conversation_id,
     store_transcript,
     select_model_id,
     validate_attachments_metadata,
```
```diff
@@ -33,6 +35,37 @@
 logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["streaming_query"])
 
+# Global agent registry to persist agents across requests
+_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)
+
+
+async def get_agent(
+    client: AsyncLlamaStackClient,
+    model_id: str,
+    system_prompt: str,
+    available_shields: list[str],
+    conversation_id: str | None,
+) -> tuple[AsyncAgent, str]:
+    """Get existing agent or create a new one with session persistence."""
+    if conversation_id is not None:
+        agent = _agent_cache.get(conversation_id)
+        if agent:
+            logger.debug("Reusing existing agent with key: %s", conversation_id)
+            return agent, conversation_id
+
+    logger.debug("Creating new agent")
+    agent = AsyncAgent(
+        client,  # type: ignore[arg-type]
+        model=model_id,
+        instructions=system_prompt,
```

> **Contributor** (on `instructions=system_prompt,`): Likewise; remove this line.

```diff
+        input_shields=available_shields if available_shields else [],
+        tools=[],  # mcp config ?
+        enable_session_persistence=True,
+    )
+    conversation_id = await agent.create_session(get_suid())
+    _agent_cache[conversation_id] = agent
+    return agent, conversation_id
 
 
 def format_stream_data(d: dict) -> str:
     """Format outbound data in the Event Stream Format."""
```
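
`format_stream_data` appears here only as unchanged context. For orientation, server-sent events frame each message as a `data:` line terminated by a blank line; a sketch of what such a helper plausibly looks like (an assumption about its body, not the repository's actual implementation):

```python
import json


def format_stream_data(d: dict) -> str:
    """Frame a dict as a server-sent event: `data: <json>` plus a blank line."""
    return f"data: {json.dumps(d)}\n\n"


# e.g. a single token chunk streamed to the client
print(format_stream_data({"event": "token", "data": {"token": "Hello"}}), end="")
```
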
```diff
@@ -139,8 +172,9 @@ async def streaming_query_endpoint_handler(
     # try to get Llama Stack client
     client = await get_async_llama_stack_client(llama_stack_config)
     model_id = select_model_id(await client.models.list(), query_request)
-    conversation_id = retrieve_conversation_id(query_request)
-    response = await retrieve_response(client, model_id, query_request)
+    response, conversation_id = await retrieve_response(
+        client, model_id, query_request
+    )
 
     async def response_generator(turn_response: Any) -> AsyncIterator[str]:
         """Generate SSE formatted streaming response."""
```
```diff
@@ -191,7 +225,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
 
 async def retrieve_response(
     client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest
-) -> Any:
+) -> tuple[Any, str]:
     """Retrieve response from LLMs and agents."""
     available_shields = [shield.identifier for shield in await client.shields.list()]
     if not available_shields:
```
```diff
@@ -212,24 +246,24 @@
     if query_request.attachments:
         validate_attachments_metadata(query_request.attachments)
 
-    agent = AsyncAgent(
-        client,  # type: ignore[arg-type]
-        model=model_id,
-        instructions=system_prompt,
-        input_shields=available_shields if available_shields else [],
-        tools=[],
+    agent, conversation_id = await get_agent(
+        client,
+        model_id,
+        system_prompt,
+        available_shields,
+        query_request.conversation_id,
     )
-    session_id = await agent.create_session("chat_session")
-    logger.debug("Session ID: %s", session_id)
 
+    logger.debug("Session ID: %s", conversation_id)
     vector_db_ids = [
         vector_db.identifier for vector_db in await client.vector_dbs.list()
     ]
     response = await agent.create_turn(
         messages=[UserMessage(role="user", content=query_request.query)],
```

> **Contributor** (on the `messages=` line): Likewise, add a […]

```diff
-        session_id=session_id,
+        session_id=conversation_id,
         documents=query_request.get_documents(),
         stream=True,
         toolgroups=get_rag_toolgroups(vector_db_ids),
     )
 
-    return response
+    return response, conversation_id
```

> **Reviewer:** Not sure if we should be using expiringdict; the project looks abandoned [0], the last release was in 2022, and the code repository is marked as not active [1]. May I suggest using cachetools for this? Very similar syntax and a well-maintained project [2].
>
> [0] https://pypi.org/project/expiringdict/#history
> [1] https://app.travis-ci.com/github/mailgun/expiringdict/
> [2] https://pypi.org/project/cachetools/#history

> **Author:** Great suggestion. Switched to cachetools.
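
The swap is close to mechanical, since both libraries expose a dict-like interface. A sketch of the before and after; the `ExpiringDict(max_len, max_age_seconds)` constructor is taken from the expiringdict docs and shown only for comparison:

```python
# Before (expiringdict) -- unmaintained; constructor per its documentation
# from expiringdict import ExpiringDict
# _agent_cache = ExpiringDict(max_len=1000, max_age_seconds=3600)

# After (cachetools) -- same dict-like reads and writes, actively maintained
from cachetools import TTLCache

_agent_cache: TTLCache[str, object] = TTLCache(maxsize=1000, ttl=3600)

_agent_cache["conversation-id"] = object()   # store an agent under its conversation ID
agent = _agent_cache.get("conversation-id")  # retrieve; None once expired or evicted
```
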