Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import UserMessage

from fastapi import APIRouter, Request
from fastapi import APIRouter

from pydantic import BaseModel

from configuration import configuration
from models.config import LLamaStackConfiguration
Expand All @@ -26,8 +28,12 @@
}


class LLMRequest(BaseModel):
query: str
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep models in the models/ directory.

Also, there's now a QueryRequest model that I think we should use instead of introducing a new one:

class QueryRequest(BaseModel):



@router.post("/query", responses=query_response)
def info_endpoint_handler(request: Request, query: str) -> QueryResponse:
def info_endpoint_handler(request: LLMRequest, query: str) -> QueryResponse:
llama_stack_config = configuration.llama_stack_configuration
logger.info("LLama stack config: %s", llama_stack_config)

Expand Down
302 changes: 302 additions & 0 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
"""Handler for REST API call to provide answer to query."""

import json
import logging
import re
from typing import Any, Optional, Iterator, AsyncGenerator, Mapping

from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import (
TurnStreamPrintableEvent,
TurnStreamEventPrinter,
)
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import UserMessage
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from app.endpoints.query import get_llama_stack_client, LLMRequest
from configuration import configuration

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["streaming_query"])


query_response: dict[int | str, dict[str, Any]] = {
200: {
"query": "User query",
"answer": "LLM answer",
},
}


@router.post("/streaming_query", responses=query_response)
def info_endpoint_handler(request: LLMRequest) -> StreamingResponse:
llama_stack_config = configuration.llama_stack_configuration
logger.info("LLama stack config: %s", llama_stack_config)

client = get_llama_stack_client(llama_stack_config)

# retrieve list of available models
models = client.models.list()

# select the first LLM
llm = next(m for m in models if m.model_type == "llm")
model_id = llm.identifier
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep compatibility with the /query endpoint, which supports passing a model/provider in the request


logger.info("Model: %s", model_id)

response = retrieve_response(client, model_id, request.query)

return StreamingResponse(
response_processing_wrapper(
request,
response,
)
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use the QueryResponse for this ? I don't see why introducing a new StreamingResponse model.



def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> str:

available_shields = [shield.identifier for shield in client.shields.list()]
if not available_shields:
logger.info("No available shields. Disabling safety")
else:
logger.info(f"Available shields found: {available_shields}")

available_vector_dbs = [
vector_db.identifier for vector_db in client.vector_dbs.list()
]
if not available_vector_dbs:
raise RuntimeError("No available vector DBs.")
vector_db_id = available_vector_dbs[0]

agent = Agent(
client,
model=model_id,
instructions="""You are a helpful assistant with access to the following tools.
When a tool is required to answer the user's query, respond only with <|tool_call|>
followed by a JSON list of tools used. If a tool does not exist in the provided
list of tools, notify the user that you do not have the ability to fulfill the request.
""",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, keeping compatibility with the /query endpoint which supports passing a system_prompt in the request

input_shields=available_shields if available_shields else [],
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {
"vector_db_ids": [vector_db_id],
# Defaults
"query_config": {
"chunk_size_in_tokens": 512,
"chunk_overlap_in_tokens": 0,
"chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
},
},
}
],
)
session_id = agent.create_session("chat_session")
response = agent.create_turn(
messages=[UserMessage(role="user", content=prompt)],
session_id=session_id,
)
return response
# return str(response.output_message.content)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of this code is duplicated with /query. I think we can refactor these two endpoints and have a common base of code. Otherwise we will endpoint creating a lot of inconsistencies between the two.



async def response_processing_wrapper(
request: LLMRequest,
generator: AsyncGenerator[Any, None],
) -> AsyncGenerator[str, None]:
"""Process the response from the generator and handle metadata and errors."""

idx = 0
logger = RAGEventLogger()
try:
for item in logger.log(generator):
yield build_yield_item(str(item), idx)
idx += 1
finally:
ref_docs = logger.printer.metadata_map
yield stream_end_event(
ref_docs,
)


def build_yield_item(item: str, idx: int) -> str:
return format_stream_data(
{
"event": "token",
"data": {"id": idx, "token": item},
}
)


def stream_end_event(ref_docs_metadata: Mapping[str, dict]):
ref_docs = []
for k, v in ref_docs_metadata.items():
ref_docs.append(
{
"doc_url": v["docs_url"],
"doc_title": v["title"], # todo
}
)
return format_stream_data(
{
"event": "end",
"data": {
"referenced_documents": ref_docs,
"truncated": False, # TODO
"input_tokens": 0, # TODO
"output_tokens": 0, # TODO
},
"available_quotas": 0, # TODO
}
)


def format_stream_data(d: dict) -> str:
"""Format outbound data in the Event Stream Format."""
data = json.dumps(d)
return f"data: {data}\n\n"


class TurnStreamPrintableEventEx(TurnStreamPrintableEvent):
def __str__(self) -> str:
if self.role is not None:
return f"\n\n`{self.role}>` {self.content}"
else:
return f"{self.content}"


class RAGTurnStreamEventPrinter(TurnStreamEventPrinter):
metadata_pattern = re.compile(r"\nMetadata: (\{.+\})\n")

def __init__(self):
super().__init__()
self.metadata_map = {}

def _yield_printable_events(
self,
chunk: Any,
previous_event_type: Optional[str] = None,
previous_step_type: Optional[str] = None,
) -> Iterator[TurnStreamPrintableEventEx]:
if hasattr(chunk, "error"):
yield TurnStreamPrintableEventEx(
role=None, content=chunk.error["message"], color="red"
)
return

event = chunk.event
event_type = event.payload.event_type

if event_type in {"turn_start", "turn_complete", "turn_awaiting_input"}:
# Currently not logging any turn related info
yield TurnStreamPrintableEventEx(
role=None, content="", end="", color="grey"
)
return

step_type = event.payload.step_type
# handle safety
if step_type == "shield_call" and event_type == "step_complete":
violation = event.payload.step_details.violation
if not violation:
yield TurnStreamPrintableEventEx(
role=step_type, content="No Violation", color="magenta"
)
else:
yield TurnStreamPrintableEventEx(
role=step_type,
content=f"{violation.metadata} {violation.user_message}",
color="red",
)

# handle inference
if step_type == "inference":
if event_type == "step_start":
yield TurnStreamPrintableEventEx(
role=step_type, content="", end="", color="yellow"
)
elif event_type == "step_progress":
if event.payload.delta.type == "tool_call":
if isinstance(event.payload.delta.tool_call, str):
yield TurnStreamPrintableEventEx(
role=None,
content=event.payload.delta.tool_call,
end="",
color="cyan",
)
elif event.payload.delta.type == "text":
yield TurnStreamPrintableEventEx(
role=None,
content=event.payload.delta.text,
end="",
color="yellow",
)
else:
# step complete
yield TurnStreamPrintableEventEx(role=None, content="")

# handle tool_execution
if step_type == "tool_execution" and event_type == "step_complete":
# Only print tool calls and responses at the step_complete event
details = event.payload.step_details
for t in details.tool_calls:
yield TurnStreamPrintableEventEx(
role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
)

for r in details.tool_responses:
if r.tool_name == "query_from_memory":
inserted_context = super().interleaved_content_as_str(r.content)
content = f"fetched {len(inserted_context)} bytes from memory"

yield TurnStreamPrintableEventEx(
role=step_type,
content=content,
color="cyan",
)
else:
# Referenced documents support
if r.tool_name == "knowledge_search" and r.content:
summary = ""
for i, text_content_item in enumerate(r.content):
if isinstance(text_content_item, TextContentItem):
if i == 0:
summary = text_content_item.text
summary = summary[: summary.find("\n")]
matches = self.metadata_pattern.findall(
text_content_item.text
)
if matches:
for match in matches:
meta = json.loads(match.replace("'", '"'))
self.metadata_map[meta["document_id"]] = meta
yield TurnStreamPrintableEventEx(
role=step_type,
content=f"\nTool:{r.tool_name} Summary:{summary}\n",
color="green",
)
else:
yield TurnStreamPrintableEventEx(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
)


class RAGEventLogger:
printer: RAGTurnStreamEventPrinter

def log(
self, event_generator: Iterator[Any]
) -> Iterator[TurnStreamPrintableEventEx]:
self.printer = RAGTurnStreamEventPrinter()
for chunk in event_generator:
yield from self.printer.yield_printable_events(chunk)
5 changes: 3 additions & 2 deletions src/app/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from fastapi import FastAPI

from app.endpoints import info, models, root, query, health, config
from app.endpoints import info, models, root, query, health, config, streaming_query


def include_routers(app: FastAPI) -> None:
Expand All @@ -15,5 +15,6 @@ def include_routers(app: FastAPI) -> None:
app.include_router(info.router, prefix="/v1")
app.include_router(models.router, prefix="/v1")
app.include_router(query.router, prefix="/v1")
app.include_router(health.router, prefix="/v1")
app.include_router(streaming_query.router, prefix="/v1")
app.include_router(health.router, prefix="")
app.include_router(config.router, prefix="/v1")