2 changes: 2 additions & 0 deletions pyproject.toml
@@ -11,6 +11,8 @@ dependencies = [
"uvicorn>=0.34.3",
"llama-stack>=0.2.13",
"rich>=14.0.0",
"expiringdict>=1.2.2",
"cachetools>=6.1.0",
]

[tool.pdm]
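The new cachetools dependency provides the TTLCache that the endpoints below use as a per-conversation agent registry (expiringdict is also added to pyproject.toml, but only cachetools appears in the code shown here). A minimal standalone sketch of the TTLCache behaviour the registry relies on: entries are dropped once maxsize is exceeded, or ttl seconds after they were stored.

from time import sleep

from cachetools import TTLCache

# Same parameters as the agent registry below: at most 1000 entries, kept for one hour.
cache: TTLCache = TTLCache(maxsize=1000, ttl=3600)

cache["conversation-123"] = object()   # stand-in for an Agent instance
print(cache.get("conversation-123"))   # hit while the entry is fresh
print(cache.get("missing-key"))        # None: .get() does not raise on a miss

# A tiny cache with a one-second TTL makes the expiry visible.
short_lived: TTLCache = TTLCache(maxsize=10, ttl=1)
short_lived["k"] = "v"
sleep(1.1)
print(short_lived.get("k"))            # None: the entry has expired

Two properties worth noting for the code below: a module-level TTLCache is per-process state, so in a multi-worker uvicorn deployment each worker keeps its own registry, and TTLCache does not refresh an entry's timer on reads, so an agent stored once at creation is evicted an hour later even if the conversation is still active.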
91 changes: 56 additions & 35 deletions src/app/endpoints/query.py
@@ -6,8 +6,10 @@
import os
from pathlib import Path
from typing import Any
from llama_stack_client.lib.agents.agent import Agent

from cachetools import TTLCache # type: ignore

from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client import APIConnectionError
from llama_stack_client import LlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore
@@ -32,6 +34,8 @@
logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["query"])

# Global agent registry to persist agents across requests
_agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600)

query_response: dict[int | str, dict[str, Any]] = {
200: {
@@ -56,16 +60,33 @@ def is_transcripts_enabled() -> bool:
return not configuration.user_data_collection_configuration.transcripts_disabled


def retrieve_conversation_id(query_request: QueryRequest) -> str:
"""Retrieve conversation ID based on existing ID or on newly generated one."""
conversation_id = query_request.conversation_id

# Generate a new conversation ID if not provided
if not conversation_id:
conversation_id = get_suid()
logger.info("Generated new conversation ID: %s", conversation_id)

return conversation_id
def get_agent(
client: LlamaStackClient,
model_id: str,
system_prompt: str,
available_shields: list[str],
conversation_id: str | None,
) -> tuple[Agent, str]:
"""Get existing agent or create a new one with session persistence."""
if conversation_id is not None:
agent = _agent_cache.get(conversation_id)
if agent:
logger.debug("Reusing existing agent with key: %s", conversation_id)
return agent, conversation_id

logger.debug("Creating new agent")
# TODO(lucasagomes): move to ReActAgent
agent = Agent(
client,
model=model_id,
instructions=system_prompt,
input_shields=available_shields if available_shields else [],
tools=[mcp.name for mcp in configuration.mcp_servers],
enable_session_persistence=True,
)
conversation_id = agent.create_session(get_suid())
_agent_cache[conversation_id] = agent
return agent, conversation_id


@router.post("/query", responses=query_response)
@@ -83,8 +104,9 @@ def query_endpoint_handler(
# try to get Llama Stack client
client = get_llama_stack_client(llama_stack_config)
model_id = select_model_id(client.models.list(), query_request)
conversation_id = retrieve_conversation_id(query_request)
response = retrieve_response(client, model_id, query_request, auth)
response, conversation_id = retrieve_response(
client, model_id, query_request, auth
)

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
@@ -163,7 +185,7 @@ def retrieve_response(
model_id: str,
query_request: QueryRequest,
token: str,
) -> str:
) -> tuple[str, str]:
"""Retrieve response from LLMs and agents."""
available_shields = [shield.identifier for shield in client.shields.list()]
if not available_shields:
@@ -184,40 +206,39 @@
if query_request.attachments:
validate_attachments_metadata(query_request.attachments)

# Build mcp_headers config dynamically for all MCP servers
# this will allow the agent to pass the user token to the MCP servers
agent, conversation_id = get_agent(
client,
model_id,
system_prompt,
available_shields,
query_request.conversation_id,
)

mcp_headers = {}
if token:
for mcp_server in configuration.mcp_servers:
mcp_headers[mcp_server.url] = {
"Authorization": f"Bearer {token}",
}
# TODO(lucasagomes): move to ReActAgent
agent = Agent(
client,
model=model_id,
instructions=system_prompt,
input_shields=available_shields if available_shields else [],
tools=[mcp.name for mcp in configuration.mcp_servers],
extra_headers={
"X-LlamaStack-Provider-Data": json.dumps(
{
"mcp_headers": mcp_headers,
}
),
},
)
session_id = agent.create_session("chat_session")
logger.debug("Session ID: %s", session_id)

agent.extra_headers = {
"X-LlamaStack-Provider-Data": json.dumps(
{
"mcp_headers": mcp_headers,
}
),
}

vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
response = agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
session_id=session_id,
session_id=conversation_id,
documents=query_request.get_documents(),
stream=False,
toolgroups=get_rag_toolgroups(vector_db_ids),
)
return str(response.output_message.content) # type: ignore[union-attr]

return str(response.output_message.content), conversation_id # type: ignore[union-attr]


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
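Taken together, the query.py changes let a caller keep a conversation going by echoing back the conversation_id it received. A rough client-side sketch, assuming the service is reachable on localhost:8080 and that the /query response body exposes the ID under a conversation_id field (the request field is confirmed by QueryRequest in the diff above; the response field name and base URL are assumptions).

import requests

BASE_URL = "http://localhost:8080"  # assumed deployment address

# First turn: no conversation_id, so the service creates a new agent and session.
first = requests.post(
    f"{BASE_URL}/query",
    json={"query": "How do I restart the cluster?"},
    timeout=60,
)
first.raise_for_status()
conversation_id = first.json()["conversation_id"]  # assumed response field

# Second turn: reuse the ID so the cached agent (and its session history) is picked
# up again, until the TTLCache evicts it (maxsize 1000, ttl one hour).
second = requests.post(
    f"{BASE_URL}/query",
    json={
        "query": "And how do I check that it came back healthy?",
        "conversation_id": conversation_id,
    },
    timeout=60,
)
second.raise_for_status()
print(second.json())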
82 changes: 67 additions & 15 deletions src/app/endpoints/streaming_query.py
@@ -4,6 +4,8 @@
import logging
from typing import Any, AsyncIterator

from cachetools import TTLCache # type: ignore

from llama_stack_client import APIConnectionError
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
from llama_stack_client import AsyncLlamaStackClient # type: ignore
@@ -19,12 +21,12 @@
from utils.auth import auth_dependency
from utils.endpoints import check_configuration_loaded
from utils.common import retrieve_user_id
from utils.suid import get_suid


from app.endpoints.query import (
get_rag_toolgroups,
is_transcripts_enabled,
retrieve_conversation_id,
store_transcript,
select_model_id,
validate_attachments_metadata,
@@ -33,6 +35,37 @@
logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["streaming_query"])

# Global agent registry to persist agents across requests
_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)


async def get_agent(
client: AsyncLlamaStackClient,
model_id: str,
system_prompt: str,
available_shields: list[str],
conversation_id: str | None,
) -> tuple[AsyncAgent, str]:
"""Get existing agent or create a new one with session persistence."""
if conversation_id is not None:
agent = _agent_cache.get(conversation_id)
if agent:
logger.debug("Reusing existing agent with key: %s", conversation_id)
return agent, conversation_id

logger.debug("Creating new agent")
agent = AsyncAgent(
client, # type: ignore[arg-type]
model=model_id,
instructions=system_prompt,
input_shields=available_shields if available_shields else [],
tools=[mcp.name for mcp in configuration.mcp_servers],
enable_session_persistence=True,
)
conversation_id = await agent.create_session(get_suid())
_agent_cache[conversation_id] = agent
return agent, conversation_id


def format_stream_data(d: dict) -> str:
"""Format outbound data in the Event Stream Format."""
@@ -139,8 +172,9 @@ async def streaming_query_endpoint_handler(
# try to get Llama Stack client
client = await get_async_llama_stack_client(llama_stack_config)
model_id = select_model_id(await client.models.list(), query_request)
conversation_id = retrieve_conversation_id(query_request)
response = await retrieve_response(client, model_id, query_request)
response, conversation_id = await retrieve_response(
client, model_id, query_request, auth
)

async def response_generator(turn_response: Any) -> AsyncIterator[str]:
"""Generate SSE formatted streaming response."""
@@ -190,8 +224,11 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:


async def retrieve_response(
client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest
) -> Any:
client: AsyncLlamaStackClient,
model_id: str,
query_request: QueryRequest,
token: str,
) -> tuple[Any, str]:
"""Retrieve response from LLMs and agents."""
available_shields = [shield.identifier for shield in await client.shields.list()]
if not available_shields:
@@ -212,24 +249,39 @@ async def retrieve_response(
if query_request.attachments:
validate_attachments_metadata(query_request.attachments)

agent = AsyncAgent(
client, # type: ignore[arg-type]
model=model_id,
instructions=system_prompt,
input_shields=available_shields if available_shields else [],
tools=[],
agent, conversation_id = await get_agent(
client,
model_id,
system_prompt,
available_shields,
query_request.conversation_id,
)
session_id = await agent.create_session("chat_session")
logger.debug("Session ID: %s", session_id)

mcp_headers = {}
if token:
for mcp_server in configuration.mcp_servers:
mcp_headers[mcp_server.url] = {
"Authorization": f"Bearer {token}",
}

agent.extra_headers = {
"X-LlamaStack-Provider-Data": json.dumps(
{
"mcp_headers": mcp_headers,
}
),
}

logger.debug("Session ID: %s", conversation_id)
vector_db_ids = [
vector_db.identifier for vector_db in await client.vector_dbs.list()
]
response = await agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
session_id=session_id,
session_id=conversation_id,
documents=query_request.get_documents(),
stream=True,
toolgroups=get_rag_toolgroups(vector_db_ids),
)

return response
return response, conversation_id
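The streaming endpoint mirrors the same agent-reuse pattern and now also forwards the caller's token to the MCP servers through the X-LlamaStack-Provider-Data header. A rough consumer sketch, assuming the route is /streaming_query and that events are emitted as standard "data: <json>" SSE lines; the exact route path and the event payload shape produced by format_stream_data() are not visible in this diff.

import json

import httpx

BASE_URL = "http://localhost:8080"  # assumed deployment address
payload = {
    "query": "Summarise the last deployment failure.",
    # "conversation_id": "<id from a previous turn>",  # include this to reuse a cached agent
}

with httpx.Client(timeout=None) as client:
    # The route name is an assumption; only the router tag is visible in the diff.
    with client.stream("POST", f"{BASE_URL}/streaming_query", json=payload) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank keep-alives and SSE comment lines
            event = json.loads(line[len("data: "):])
            print(event)  # payload shape is defined by format_stream_data()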