Skip to content

Commit 7173f72

Browse files
authored
Merge pull request #550 from Anxhela21/anx/rag-context-retrieve-pr
LCORE-601: Add RAG chunks in query response
2 parents fc7cbe1 + 930cd74 commit 7173f72

File tree

10 files changed

+400
-25
lines changed

10 files changed

+400
-25
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ dependencies = [
4949
# Used by Llama Stack version checker
5050
"semver<4.0.0",
5151
# Used by authorization resolvers
52-
"jsonpath-ng>=1.6.1",
52+
"jsonpath-ng>=1.6.1"
5353
]
5454

5555

@@ -176,6 +176,7 @@ addopts = [
176176

177177
[tool.pylint.main]
178178
source-roots = "src"
179+
ignore = ["query.py"]
179180

180181
[build-system]
181182
requires = ["pdm-backend"]

run.yaml

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ providers:
6060
provider_id: meta-reference
6161
provider_type: inline::meta-reference
6262
inference:
63+
- provider_id: sentence-transformers # Can be any embedding provider
64+
provider_type: inline::sentence-transformers
65+
config: {}
6366
- provider_id: openai
6467
provider_type: remote::openai
6568
config:
@@ -99,14 +102,17 @@ providers:
99102
- provider_id: model-context-protocol
100103
provider_type: remote::model-context-protocol
101104
config: {}
105+
- provider_id: rag-runtime
106+
provider_type: inline::rag-runtime
107+
config: {}
102108
vector_io:
103109
- config:
104110
kvstore:
105-
db_path: .llama/distributions/ollama/faiss_store.db
111+
db_path: .llama/distributions/ollama/faiss_store.db # Location of vector database
106112
namespace: null
107113
type: sqlite
108114
provider_id: faiss
109-
provider_type: inline::faiss
115+
provider_type: inline::faiss # Or preferred vector DB
110116
scoring_fns: []
111117
server:
112118
auth: null
@@ -117,10 +123,23 @@ server:
117123
tls_certfile: null
118124
tls_keyfile: null
119125
shields: []
120-
vector_dbs: []
121-
126+
vector_dbs:
127+
- vector_db_id: my_knowledge_base
128+
embedding_model: sentence-transformers/all-mpnet-base-v2
129+
embedding_dimension: 768
130+
provider_id: faiss
122131
models:
132+
- metadata:
133+
embedding_dimension: 768 # Depends on chosen model
134+
model_id: sentence-transformers/all-mpnet-base-v2 # Example embedding model
135+
provider_id: sentence-transformers
136+
provider_model_id: sentence-transformers/all-mpnet-base-v2 # Location of embedding model
137+
model_type: embedding
123138
- model_id: gpt-4-turbo
124139
provider_id: openai
125140
model_type: llm
126141
provider_model_id: gpt-4-turbo
142+
143+
tool_groups:
144+
- toolgroup_id: builtin::rag
145+
provider_id: rag-runtime

src/app/endpoints/query.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Annotated, Any, Optional, cast
99

1010
from fastapi import APIRouter, Depends, HTTPException, Request, status
11+
from pydantic import AnyUrl
1112
from llama_stack_client import (
1213
APIConnectionError,
1314
AsyncLlamaStackClient, # type: ignore
@@ -39,6 +40,7 @@
3940
ForbiddenResponse,
4041
QueryResponse,
4142
ReferencedDocument,
43+
ToolCall,
4244
UnauthorizedResponse,
4345
)
4446
from utils.endpoints import (
@@ -248,6 +250,10 @@ async def query_endpoint_handler(
248250
# Update metrics for the LLM call
249251
metrics.llm_calls_total.labels(provider_id, model_id).inc()
250252

253+
# Convert RAG chunks to dictionary format once for reuse
254+
logger.info("Processing RAG chunks...")
255+
rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks]
256+
251257
if not is_transcripts_enabled():
252258
logger.debug("Transcript collection is disabled in the configuration")
253259
else:
@@ -260,23 +266,63 @@ async def query_endpoint_handler(
260266
query=query_request.query,
261267
query_request=query_request,
262268
summary=summary,
263-
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
269+
rag_chunks=rag_chunks_dict,
264270
truncated=False, # TODO(lucasagomes): implement truncation as part of quota work
265271
attachments=query_request.attachments or [],
266272
)
267273

274+
logger.info("Persisting conversation details...")
268275
persist_user_conversation_details(
269276
user_id=user_id,
270277
conversation_id=conversation_id,
271278
model=model_id,
272279
provider_id=provider_id,
273280
)
274281

275-
return QueryResponse(
282+
# Convert tool calls to response format
283+
logger.info("Processing tool calls...")
284+
tool_calls = [
285+
ToolCall(
286+
tool_name=tc.name,
287+
arguments=(
288+
tc.args if isinstance(tc.args, dict) else {"query": str(tc.args)}
289+
),
290+
result=(
291+
{"response": tc.response}
292+
if tc.response and tc.name != constants.DEFAULT_RAG_TOOL
293+
else None
294+
),
295+
)
296+
for tc in summary.tool_calls
297+
]
298+
299+
logger.info("Extracting referenced documents...")
300+
referenced_docs = []
301+
doc_sources = set()
302+
for chunk in summary.rag_chunks:
303+
if chunk.source and chunk.source not in doc_sources:
304+
doc_sources.add(chunk.source)
305+
referenced_docs.append(
306+
ReferencedDocument(
307+
doc_url=(
308+
AnyUrl(chunk.source)
309+
if chunk.source.startswith("http")
310+
else None
311+
),
312+
doc_title=chunk.source,
313+
)
314+
)
315+
316+
logger.info("Building final response...")
317+
response = QueryResponse(
276318
conversation_id=conversation_id,
277319
response=summary.llm_response,
320+
rag_chunks=summary.rag_chunks if summary.rag_chunks else [],
321+
tool_calls=tool_calls if tool_calls else None,
278322
referenced_documents=referenced_documents,
279323
)
324+
logger.info("Query processing completed successfully!")
325+
return response
280326

281327
# connection to Llama Stack server
282328
except APIConnectionError as e:

src/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
DEFAULT_JWT_UID_CLAIM = "user_id"
5353
DEFAULT_JWT_USER_NAME_CLAIM = "username"
5454

55+
# default RAG tool value
56+
DEFAULT_RAG_TOOL = "knowledge_search"
5557

5658
# PostgreSQL connection constants
5759
# See: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLMODE

src/models/responses.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@ class ModelsResponse(BaseModel):
3434
)
3535

3636

37+
class RAGChunk(BaseModel):
    """A single retrieved chunk reported alongside a query response.

    Attributes:
        content: Text of the chunk (required).
        source: Source document or URL of the chunk; ``None`` when not given.
        score: Relevance score of the chunk; ``None`` when not given.
    """

    content: str = Field(description="The content of the chunk")
    source: Optional[str] = Field(
        default=None,
        description="Source document or URL",
    )
    score: Optional[float] = Field(
        default=None,
        description="Relevance score",
    )
43+
44+
45+
class ToolCall(BaseModel):
    """A tool invocation that happened while the response was generated.

    Attributes:
        tool_name: Name of the tool that was called (required).
        arguments: Mapping of the arguments passed to the tool (required).
        result: Mapping holding the tool result; ``None`` when not given.
    """

    tool_name: str = Field(description="Name of the tool called")
    arguments: dict[str, Any] = Field(
        description="Arguments passed to the tool",
    )
    result: Optional[dict[str, Any]] = Field(
        default=None,
        description="Result from the tool",
    )
51+
52+
3753
class ReferencedDocument(BaseModel):
3854
"""Model representing a document referenced in generating a response.
3955
@@ -42,27 +58,27 @@ class ReferencedDocument(BaseModel):
4258
doc_title: Title of the referenced doc.
4359
"""
4460

45-
doc_url: AnyUrl = Field(description="URL of the referenced document")
61+
doc_url: Optional[AnyUrl] = Field(
62+
None, description="URL of the referenced document"
63+
)
4664

4765
doc_title: str = Field(description="Title of the referenced document")
4866

4967

50-
# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
51-
# we are keeping it simple. The missing fields are:
52-
# - truncated: Set to True if conversation history was truncated to be within context window.
53-
# - input_tokens: Number of tokens sent to LLM
54-
# - output_tokens: Number of tokens received from LLM
55-
# - available_quotas: Quota available as measured by all configured quota limiters
56-
# - tool_calls: List of tool requests.
57-
# - tool_results: List of tool results.
58-
# See LLMResponse in ols-service for more details.
5968
class QueryResponse(BaseModel):
6069
"""Model representing LLM response to a query.
6170
6271
Attributes:
6372
conversation_id: The optional conversation ID (UUID).
6473
response: The response.
74+
rag_chunks: List of RAG chunks used to generate the response.
6575
referenced_documents: The URLs and titles for the documents used to generate the response.
76+
tool_calls: List of tool calls made during response generation.
77+
TODO: truncated: Whether conversation history was truncated.
78+
TODO: input_tokens: Number of tokens sent to LLM.
79+
TODO: output_tokens: Number of tokens received from LLM.
80+
TODO: available_quotas: Quota available as measured by all configured quota limiters
81+
TODO: tool_results: List of tool results.
6682
"""
6783

6884
conversation_id: Optional[str] = Field(
@@ -78,6 +94,13 @@ class QueryResponse(BaseModel):
7894
],
7995
)
8096

97+
rag_chunks: list[RAGChunk] = []
98+
99+
tool_calls: Optional[list[ToolCall]] = Field(
100+
None,
101+
description="List of tool calls made during response generation",
102+
)
103+
81104
referenced_documents: list[ReferencedDocument] = Field(
82105
default_factory=list,
83106
description="List of documents referenced in generating the response",
@@ -99,6 +122,20 @@ class QueryResponse(BaseModel):
99122
{
100123
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
101124
"response": "Operator Lifecycle Manager (OLM) helps users install...",
125+
"rag_chunks": [
126+
{
127+
"content": "OLM is a component of the Operator Framework toolkit...",
128+
"source": "kubernetes-docs/operators.md",
129+
"score": 0.95,
130+
}
131+
],
132+
"tool_calls": [
133+
{
134+
"tool_name": "knowledge_search",
135+
"arguments": {"query": "operator lifecycle manager"},
136+
"result": {"chunks_found": 5},
137+
}
138+
],
102139
"referenced_documents": [
103140
{
104141
"doc_url": "https://docs.openshift.com/"

src/utils/transcripts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-
3939
query: str,
4040
query_request: QueryRequest,
4141
summary: TurnSummary,
42-
rag_chunks: list[str],
42+
rag_chunks: list[dict],
4343
truncated: bool,
4444
attachments: list[Attachment],
4545
) -> None:
@@ -52,7 +52,7 @@ def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-
5252
query: The query (without attachments).
5353
query_request: The request containing a query.
5454
summary: Summary of the query/response turn.
55-
rag_chunks: The list of `RagChunk` objects.
55+
rag_chunks: The list of serialized `RAGChunk` dictionaries.
5656
truncated: The flag indicating if the history was truncated.
5757
attachments: The list of `Attachment` objects.
5858
"""

src/utils/types.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Common types for the project."""
22

33
from typing import Any, Optional
4-
4+
import json
55
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
66
from llama_stack_client.lib.agents.tool_parser import ToolParser
77
from llama_stack_client.types.shared.completion_message import CompletionMessage
88
from llama_stack_client.types.shared.tool_call import ToolCall
99
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
10-
from pydantic.main import BaseModel
10+
from pydantic import BaseModel
11+
from models.responses import RAGChunk
12+
from constants import DEFAULT_RAG_TOOL
1113

1214

1315
class Singleton(type):
@@ -61,18 +63,73 @@ class TurnSummary(BaseModel):
6163

6264
llm_response: str
6365
tool_calls: list[ToolCallSummary]
66+
rag_chunks: list[RAGChunk] = []
6467

6568
def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None:
    """Record the tool calls of a llama tool execution step.

    Each call in ``tec.tool_calls`` is paired with the entry of
    ``tec.tool_responses`` that has the same ``call_id`` (if any) and is
    appended to ``self.tool_calls`` as a ``ToolCallSummary``. A non-empty
    response of the ``DEFAULT_RAG_TOOL`` tool is additionally handed to
    ``_extract_rag_chunks_from_response``.
    """
    requests = {call.call_id: call for call in tec.tool_calls}
    replies = {reply.call_id: reply for reply in tec.tool_responses}

    for call_id, call in requests.items():
        reply = replies.get(call_id)
        # A call without a matching response is stored with response=None.
        if reply:
            reply_text = interleaved_content_as_str(reply.content)
        else:
            reply_text = None

        entry = ToolCallSummary(
            id=call_id,
            name=call.tool_name,
            args=call.arguments,
            response=reply_text,
        )
        self.tool_calls.append(entry)

        # Extract RAG chunks from knowledge_search tool responses
        if call.tool_name == DEFAULT_RAG_TOOL and reply and reply_text:
            self._extract_rag_chunks_from_response(reply_text)
90+
91+
def _extract_rag_chunks_from_response(self, response_content: str) -> None:
    """Extract RAG chunks from tool response content.

    The response is parsed as JSON and interpreted as follows:

    * a dict with a ``"chunks"`` key: every entry of ``data["chunks"]``
      becomes one ``RAGChunk`` (``content`` defaults to ``""``);
    * a list: every dict entry becomes one ``RAGChunk`` (``content``
      defaults to ``str(entry)``); non-dict entries are skipped;
    * any other valid JSON value: nothing is recorded.

    If the content is not valid JSON, or building a chunk fails, the
    whole non-blank response is recorded as a single chunk whose source is
    ``DEFAULT_RAG_TOOL``.

    Chunks are collected locally and added to ``self.rag_chunks`` only
    once parsing has finished, so a malformed entry in the middle of the
    payload can no longer leave partially extracted chunks next to the
    raw-response fallback chunk.

    Parameters:
        response_content: Textual content of the tool response.
    """
    extracted: list[RAGChunk] = []
    try:
        data = json.loads(response_content)
        if isinstance(data, dict) and "chunks" in data:
            for chunk in data["chunks"]:
                extracted.append(
                    RAGChunk(
                        content=chunk.get("content", ""),
                        source=chunk.get("source"),
                        score=chunk.get("score"),
                    )
                )
        elif isinstance(data, list):
            # Handle list of chunks
            for chunk in data:
                if isinstance(chunk, dict):
                    extracted.append(
                        RAGChunk(
                            content=chunk.get("content", str(chunk)),
                            source=chunk.get("source"),
                            score=chunk.get("score"),
                        )
                    )
        # NOTE(review): other valid JSON values (scalars, dicts without
        # "chunks") yield no chunk, as before - confirm this is intended.
    except (KeyError, AttributeError, TypeError, ValueError):
        # ValueError also covers json.JSONDecodeError (not JSON at all).
        # NOTE(review): relies on pydantic's ValidationError being a
        # ValueError subclass, as the original handler did - confirm.
        # Discard anything collected so far and treat the entire response
        # as a single chunk instead.
        extracted = []
        if response_content.strip():
            extracted.append(
                RAGChunk(
                    content=response_content,
                    source=DEFAULT_RAG_TOOL,
                    score=None,
                )
            )
    self.rag_chunks.extend(extracted)

0 commit comments

Comments
 (0)