3 changes: 2 additions & 1 deletion pyproject.toml
@@ -49,7 +49,7 @@ dependencies = [
# Used by Llama Stack version checker
"semver<4.0.0",
# Used by authorization resolvers
"jsonpath-ng>=1.6.1",
"jsonpath-ng>=1.6.1"
]


@@ -176,6 +176,7 @@ addopts = [

[tool.pylint.main]
source-roots = "src"
ignore = ["query.py"]

[build-system]
requires = ["pdm-backend"]
27 changes: 23 additions & 4 deletions run.yaml
@@ -60,6 +60,9 @@ providers:
provider_id: meta-reference
provider_type: inline::meta-reference
inference:
- provider_id: sentence-transformers # Can be any embedding provider
provider_type: inline::sentence-transformers
config: {}
- provider_id: openai
provider_type: remote::openai
config:
@@ -99,14 +102,17 @@ providers:
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
vector_io:
- config:
kvstore:
db_path: .llama/distributions/ollama/faiss_store.db
db_path: .llama/distributions/ollama/faiss_store.db # Location of vector database
namespace: null
type: sqlite
provider_id: faiss
provider_type: inline::faiss
provider_type: inline::faiss # Or preferred vector DB
scoring_fns: []
server:
auth: null
@@ -117,10 +123,23 @@ server:
tls_certfile: null
tls_keyfile: null
shields: []
vector_dbs: []

vector_dbs:
- vector_db_id: my_knowledge_base
embedding_model: sentence-transformers/all-mpnet-base-v2
embedding_dimension: 768
provider_id: faiss
models:
- metadata:
embedding_dimension: 768 # Depends on chosen model
model_id: sentence-transformers/all-mpnet-base-v2 # Example embedding model
provider_id: sentence-transformers
provider_model_id: sentence-transformers/all-mpnet-base-v2 # Location of embedding model
model_type: embedding
- model_id: gpt-4-turbo
provider_id: openai
model_type: llm
provider_model_id: gpt-4-turbo

tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
50 changes: 48 additions & 2 deletions src/app/endpoints/query.py
@@ -8,6 +8,7 @@
from typing import Annotated, Any, Optional, cast

from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import AnyUrl
from llama_stack_client import (
APIConnectionError,
AsyncLlamaStackClient, # type: ignore
@@ -39,6 +40,7 @@
ForbiddenResponse,
QueryResponse,
ReferencedDocument,
ToolCall,
UnauthorizedResponse,
)
from utils.endpoints import (
@@ -248,6 +250,10 @@ async def query_endpoint_handler(
# Update metrics for the LLM call
metrics.llm_calls_total.labels(provider_id, model_id).inc()

# Convert RAG chunks to dictionary format once for reuse
logger.info("Processing RAG chunks...")
rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks]

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
else:
@@ -260,23 +266,63 @@
query=query_request.query,
query_request=query_request,
summary=summary,
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
rag_chunks=rag_chunks_dict,
truncated=False, # TODO(lucasagomes): implement truncation as part of quota work
attachments=query_request.attachments or [],
)

logger.info("Persisting conversation details...")
persist_user_conversation_details(
user_id=user_id,
conversation_id=conversation_id,
model=model_id,
provider_id=provider_id,
)

return QueryResponse(
# Convert tool calls to response format
logger.info("Processing tool calls...")
tool_calls = [
ToolCall(
tool_name=tc.name,
arguments=(
tc.args if isinstance(tc.args, dict) else {"query": str(tc.args)}
),
result=(
{"response": tc.response}
if tc.response and tc.name != constants.DEFAULT_RAG_TOOL
else None
),
)
for tc in summary.tool_calls
]

logger.info("Extracting referenced documents...")
referenced_docs = []
doc_sources = set()
for chunk in summary.rag_chunks:
if chunk.source and chunk.source not in doc_sources:
doc_sources.add(chunk.source)
referenced_docs.append(
ReferencedDocument(
doc_url=(
AnyUrl(chunk.source)
if chunk.source.startswith("http")
else None
),
doc_title=chunk.source,
)
)

logger.info("Building final response...")
response = QueryResponse(
conversation_id=conversation_id,
response=summary.llm_response,
rag_chunks=summary.rag_chunks if summary.rag_chunks else [],
tool_calls=tool_calls if tool_calls else None,
referenced_documents=referenced_docs,
)
logger.info("Query processing completed successfully!")
return response

# connection to Llama Stack server
except APIConnectionError as e:
2 changes: 2 additions & 0 deletions src/constants.py
@@ -52,6 +52,8 @@
DEFAULT_JWT_UID_CLAIM = "user_id"
DEFAULT_JWT_USER_NAME_CLAIM = "username"

# default RAG tool value
DEFAULT_RAG_TOOL = "knowledge_search"
Comment on lines +55 to +56
🛠️ Refactor suggestion

🧩 Analysis chain

Good centralization; verify no hardcoded tool names remain.

Constant name/readability look good. Please confirm there are no remaining string literals "knowledge_search" outside constants.


🏁 Script executed:

#!/usr/bin/env bash
# Find stray literals (exclude config/examples)
rg -nP -S "(^|[^A-Za-z0-9_])knowledge_search([^A-Za-z0-9_]|$)" -g '!**/run.yaml' -g '!**/README*' -g '!**/docs/**'

Length of output: 148


🏁 Script executed:

#!/usr/bin/env bash
set -euo pipefail
echo "Searching repository for literal 'knowledge_search' using multiple fallbacks..."

# 1) ripgrep (robust flags)
if command -v rg >/dev/null 2>&1; then
  echo ":: rg --hidden --no-ignore -uu -F -S"
  rg --hidden --no-ignore -uu -F -n "knowledge_search" -S || echo "rg: no matches or search completed"
else
  echo "rg not available"
fi

# 2) git grep fallback
if git rev-parse --is-inside-work-tree >/dev/null 2>&1; then
  echo ":: git grep -n"
  git grep -n "knowledge_search" || echo "git grep: no matches"
else
  echo "Not a git repo (git grep skipped)"
fi

# 3) find + grep fallback (exclude common heavy dirs)
echo ":: find + grep (excluding .git and node_modules)"
find . -type f -not -path "./.git/*" -not -path "./node_modules/*" -not -path "./dist/*" -print0 \
  | xargs -0 grep -nH --line-number -I "knowledge_search" || echo "find+grep: no matches"

Length of output: 6544


Replace remaining 'knowledge_search' string literals with DEFAULT_RAG_TOOL

Search found literal occurrences in source and tests; update usages to import and use src.constants.DEFAULT_RAG_TOOL (use f"builtin::rag/{DEFAULT_RAG_TOOL}" where the namespaced form is required); a sketch follows the location list below.

Locations:

  • src/models/responses.py:129
  • src/utils/types.py:87 (comment)
  • src/app/endpoints/query.py:642
  • src/app/endpoints/streaming_query.py:444
  • tests/unit/app/endpoints/test_query.py:1144
  • tests/unit/app/endpoints/test_streaming_query.py: multiple locations (71–85, 233, 242, 342, 349–350, 995–1023)
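
A minimal sketch of the suggested replacement; the call sites are hypothetical, only the constant and the namespaced form come from this PR:

from constants import DEFAULT_RAG_TOOL

# Plain form, wherever the bare literal "knowledge_search" was compared:
def is_rag_tool(tool_name: str) -> bool:
    return tool_name == DEFAULT_RAG_TOOL

# Namespaced form, wherever "builtin::rag/knowledge_search" was hardcoded:
RAG_TOOLGROUP = f"builtin::rag/{DEFAULT_RAG_TOOL}"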


# PostgreSQL connection constants
# See: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLMODE
57 changes: 47 additions & 10 deletions src/models/responses.py
@@ -34,6 +34,22 @@ class ModelsResponse(BaseModel):
)


class RAGChunk(BaseModel):
"""Model representing a RAG chunk used in the response."""

content: str = Field(description="The content of the chunk")
source: Optional[str] = Field(None, description="Source document or URL")
score: Optional[float] = Field(None, description="Relevance score")


class ToolCall(BaseModel):
"""Model representing a tool call made during response generation."""

tool_name: str = Field(description="Name of the tool called")
arguments: dict[str, Any] = Field(description="Arguments passed to the tool")
result: Optional[dict[str, Any]] = Field(None, description="Result from the tool")


class ReferencedDocument(BaseModel):
"""Model representing a document referenced in generating a response.

@@ -42,27 +58,27 @@ class ReferencedDocument(BaseModel):
doc_title: Title of the referenced doc.
"""

doc_url: AnyUrl = Field(description="URL of the referenced document")
doc_url: Optional[AnyUrl] = Field(
None, description="URL of the referenced document"
)

doc_title: str = Field(description="Title of the referenced document")


# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
# we are keeping it simple. The missing fields are:
# - truncated: Set to True if conversation history was truncated to be within context window.
# - input_tokens: Number of tokens sent to LLM
# - output_tokens: Number of tokens received from LLM
# - available_quotas: Quota available as measured by all configured quota limiters
# - tool_calls: List of tool requests.
# - tool_results: List of tool results.
# See LLMResponse in ols-service for more details.
class QueryResponse(BaseModel):
"""Model representing LLM response to a query.

Attributes:
conversation_id: The optional conversation ID (UUID).
response: The response.
rag_chunks: List of RAG chunks used to generate the response.
referenced_documents: The URLs and titles for the documents used to generate the response.
tool_calls: List of tool calls made during response generation.
TODO: truncated: Whether conversation history was truncated.
TODO: input_tokens: Number of tokens sent to LLM.
TODO: output_tokens: Number of tokens received from LLM.
TODO: available_quotas: Quota available as measured by all configured quota limiters
TODO: tool_results: List of tool results.
"""

conversation_id: Optional[str] = Field(
@@ -78,6 +94,13 @@ class QueryResponse(BaseModel):
],
)

rag_chunks: list[RAGChunk] = []

Comment on lines +97 to +98
⚠️ Potential issue

Fix mutable default list in API model.

rag_chunks uses [] which is shared across responses; switch to default_factory.

-    rag_chunks: list[RAGChunk] = []
+    rag_chunks: list[RAGChunk] = Field(
+        default_factory=list, description="List of RAG chunks used to generate the response"
+    )
🤖 Prompt for AI Agents
In src/models/responses.py around lines 93-94, the RAGChunk list is initialized
with a mutable default (rag_chunks: list[RAGChunk] = []) which is shared across
instances; replace it with a non-shared default_factory by importing Field from
pydantic (if not already imported) and change the declaration to use
Field(default_factory=list), i.e. rag_chunks: list[RAGChunk] =
Field(default_factory=list).

tool_calls: Optional[list[ToolCall]] = Field(
None,
description="List of tool calls made during response generation",
)

referenced_documents: list[ReferencedDocument] = Field(
default_factory=list,
description="List of documents referenced in generating the response",
@@ -99,6 +122,20 @@
{
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "Operator Lifecycle Manager (OLM) helps users install...",
"rag_chunks": [
{
"content": "OLM is a component of the Operator Framework toolkit...",
"source": "kubernetes-docs/operators.md",
"score": 0.95,
}
],
"tool_calls": [
{
"tool_name": "knowledge_search",
"arguments": {"query": "operator lifecycle manager"},
"result": {"chunks_found": 5},
}
],
"referenced_documents": [
{
"doc_url": "https://docs.openshift.com/"
4 changes: 2 additions & 2 deletions src/utils/transcripts.py
@@ -39,7 +39,7 @@ def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-
query: str,
query_request: QueryRequest,
summary: TurnSummary,
rag_chunks: list[str],
rag_chunks: list[dict],
truncated: bool,
attachments: list[Attachment],
) -> None:
@@ -52,7 +52,7 @@ def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-
query: The query (without attachments).
query_request: The request containing a query.
summary: Summary of the query/response turn.
rag_chunks: The list of `RagChunk` objects.
rag_chunks: The list of serialized `RAGChunk` dictionaries.
truncated: The flag indicating if the history was truncated.
attachments: The list of `Attachment` objects.
"""
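
For reference, each entry in rag_chunks is the model_dump() of a RAGChunk (the model added in src/models/responses.py by this PR); a rough sketch with illustrative values:

from models.responses import RAGChunk

chunk = RAGChunk(
    content="OLM is a component of the Operator Framework toolkit...",
    source="kubernetes-docs/operators.md",
    score=0.95,
)
# chunk.model_dump() ->
# {"content": "OLM is a component of the Operator Framework toolkit...",
#  "source": "kubernetes-docs/operators.md",
#  "score": 0.95}
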
63 changes: 60 additions & 3 deletions src/utils/types.py
@@ -1,13 +1,15 @@
"""Common types for the project."""

from typing import Any, Optional

import json
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
from llama_stack_client.lib.agents.tool_parser import ToolParser
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared.tool_call import ToolCall
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
from pydantic.main import BaseModel
from pydantic import BaseModel
from models.responses import RAGChunk
from constants import DEFAULT_RAG_TOOL


class Singleton(type):
@@ -61,18 +63,73 @@ class TurnSummary(BaseModel):

llm_response: str
tool_calls: list[ToolCallSummary]
rag_chunks: list[RAGChunk] = []

def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None:
"""Append the tool calls from a llama tool execution step."""
calls_by_id = {tc.call_id: tc for tc in tec.tool_calls}
responses_by_id = {tc.call_id: tc for tc in tec.tool_responses}
for call_id, tc in calls_by_id.items():
resp = responses_by_id.get(call_id)
response_content = (
interleaved_content_as_str(resp.content) if resp else None
)

self.tool_calls.append(
ToolCallSummary(
id=call_id,
name=tc.tool_name,
args=tc.arguments,
response=interleaved_content_as_str(resp.content) if resp else None,
response=response_content,
)
)

# Extract RAG chunks from knowledge_search tool responses
if tc.tool_name == DEFAULT_RAG_TOOL and resp and response_content:
self._extract_rag_chunks_from_response(response_content)

def _extract_rag_chunks_from_response(self, response_content: str) -> None:
"""Extract RAG chunks from tool response content."""
try:
# Parse the response to get chunks
# Try JSON first
try:
data = json.loads(response_content)
if isinstance(data, dict) and "chunks" in data:
for chunk in data["chunks"]:
self.rag_chunks.append(
RAGChunk(
content=chunk.get("content", ""),
source=chunk.get("source"),
score=chunk.get("score"),
)
)
elif isinstance(data, list):
# Handle list of chunks
for chunk in data:
if isinstance(chunk, dict):
self.rag_chunks.append(
RAGChunk(
content=chunk.get("content", str(chunk)),
source=chunk.get("source"),
score=chunk.get("score"),
)
)
except json.JSONDecodeError:
# If not JSON, treat the entire response as a single chunk
if response_content.strip():
self.rag_chunks.append(
RAGChunk(
content=response_content,
source=DEFAULT_RAG_TOOL,
score=None,
)
)
except (KeyError, AttributeError, TypeError, ValueError):
# Treat response as single chunk on data access/structure errors
if response_content.strip():
self.rag_chunks.append(
RAGChunk(
content=response_content, source=DEFAULT_RAG_TOOL, score=None
)
)
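
For illustration, the parser above accepts payloads shaped roughly like the following; this is a sketch only, since the exact format returned by the knowledge_search tool depends on the Llama Stack runtime:

import json

# 1) A dict with a "chunks" key; extra keys in each chunk are ignored.
dict_payload = json.dumps({
    "chunks": [
        {"content": "OLM manages operator installs...", "source": "docs/olm.md", "score": 0.91}
    ]
})

# 2) A bare list of chunk dicts.
list_payload = json.dumps([
    {"content": "Operators extend Kubernetes...", "source": "docs/operators.md"}
])

# 3) Non-JSON text is stored as a single chunk with source=DEFAULT_RAG_TOOL.
plain_payload = "Free-form tool output is kept as one chunk."

# e.g. TurnSummary(llm_response="", tool_calls=[])._extract_rag_chunks_from_response(dict_payload)
# would append one RAGChunk with source "docs/olm.md" and score 0.91.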