138 changes: 124 additions & 14 deletions scripts/query_rag.py
@@ -2,6 +2,8 @@
 
 import argparse
 import importlib
+import json
+import logging
 import os
 import sys
 import tempfile
@@ -30,23 +32,76 @@ def _llama_index_query(args: argparse.Namespace) -> None:
         storage_context=storage_context,
         index_id=args.product_index,
     )
 
     if args.node is not None:
-        print(storage_context.docstore.get_node(args.node))
+        node = storage_context.docstore.get_node(args.node)
+        result = {
+            "query": args.query,
+            "type": "single_node",
+            "node_id": args.node,
+            "node": {
+                "id": node.node_id,
+                "text": node.text,
+                "metadata": node.metadata if hasattr(node, 'metadata') else {}
+            }
+        }
+        if args.json:
+            print(json.dumps(result, indent=2))
+        else:
+            print(node)
     else:
         retriever = vector_index.as_retriever(similarity_top_k=args.top_k)
         nodes = retriever.retrieve(args.query)
 
         if len(nodes) == 0:
-            print(f"No nodes retrieved for query: {args.query}")
+            logging.warning(f"No nodes retrieved for query: {args.query}")
+            if args.json:
+                result = {
+                    "query": args.query,
+                    "top_k": args.top_k,
+                    "threshold": args.threshold,
+                    "nodes": []
+                }
+                print(json.dumps(result, indent=2))
             exit(1)
 
         if args.threshold > 0.0 and nodes[0].score < args.threshold:
-            print(
+            logging.warning(
                 f"Score {nodes[0].score} of the top retrieved node for query '{args.query}' "
                 f"didn't cross the minimal threshold {args.threshold}."
             )
+            if args.json:
+                result = {
+                    "query": args.query,
+                    "top_k": args.top_k,
+                    "threshold": args.threshold,
+                    "nodes": []
+                }
+                print(json.dumps(result, indent=2))
             exit(1)
-        for n in nodes:
-            print("=" * 80)
-            print(n)
+
+        # Format results
+        result = {
+            "query": args.query,
+            "top_k": args.top_k,
+            "threshold": args.threshold,
+            "nodes": []
+        }
+        for node in nodes:
+            node_data = {
+                "id": node.node_id,
+                "score": node.score,
+                "text": node.text,
+                "metadata": node.metadata if hasattr(node, 'metadata') else {}
+            }
+            result["nodes"].append(node_data)
+
+        if args.json:
+            print(json.dumps(result, indent=2))
+        else:
+            for n in nodes:
+                print("=" * 80)
+                print(n)
 
 
 def _get_db_path_dict(vector_type: str, config: dict[str, Any]) -> dict[str, Any]:
@@ -108,20 +163,56 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
 
     md = res.metadata
     if len(md["chunks"]) == 0:
-        print(f"No chunks retrieved for query: {args.query}")
+        logging.warning(f"No chunks retrieved for query: {args.query}")
+        if args.json:
+            result = {
+                "query": args.query,
+                "top_k": args.top_k,
+                "threshold": args.threshold,
+                "nodes": []
+            }
+            print(json.dumps(result, indent=2))
         exit(1)
 
     threshold = args.threshold
     if threshold > 0.0 and md.get("scores") and md["scores"][0].score < threshold:
-        print(
+        logging.warning(
            f"Score {md['scores'][0].score} of the top retrieved node for query '{args.query}' "
            f"didn't cross the minimal threshold {threshold}."
         )
+        if args.json:
+            result = {
+                "query": args.query,
+                "top_k": args.top_k,
+                "threshold": args.threshold,
+                "nodes": []
+            }
+            print(json.dumps(result, indent=2))
         exit(1)
 
-    # Method 1 to present data:
+    # Format results
+    result = {
+        "query": args.query,
+        "top_k": args.top_k,
+        "threshold": args.threshold,
+        "nodes": []
+    }
+
     for _id, chunk, score in zip(md["document_ids"], md["chunks"], md["scores"]):
-        print("=" * 80)
-        print(f"Node ID: {_id}\nScore: {score}\nText:\n{chunk}")
+        node_data = {
+            "id": _id,
+            "score": score.score if hasattr(score, 'score') else score,
+            "text": chunk,
+            "metadata": {}
+        }
+        result["nodes"].append(node_data)
+
+    if args.json:
+        print(json.dumps(result, indent=2))
+    else:
+        for _id, chunk, score in zip(md["document_ids"], md["chunks"], md["scores"]):
+            print("=" * 80)
+            print(f"Node ID: {_id}\nScore: {score}\nText:\n{chunk}")
 
     # Method 2 to present data:
     # for content in res.content:
@@ -130,7 +221,6 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
     #     else:
    #         print(content)
 
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Utility script for querying RAG database"
@@ -161,10 +251,30 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
         choices=["auto", "faiss", "llamastack-faiss", "llamastack-sqlite-vec"],
         help="vector store type to be used.",
     )
+    parser.add_argument(
+        "--json",
+        action="store_true",
+        help="Output results in JSON format",
+    )
 
     args = parser.parse_args()
 
-    print("Command line used: " + " ".join(sys.argv))
+    if args.json:
+        # In JSON mode, only show ERROR or higher to avoid polluting JSON output
+        logging.basicConfig(
+            level=logging.ERROR,
+            format='%(levelname)s: %(message)s',
+            stream=sys.stderr  # Send logs to stderr to keep stdout clean for JSON
+        )
+    else:
+        # In normal mode, show info and above
+        logging.basicConfig(
+            level=logging.INFO,
+            format='%(message)s'
+        )
+
+    if not args.json:
+        logging.info("Command line used: " + " ".join(sys.argv))
 
     vector_store_type = args.vector_store_type
     if args.vector_store_type == "auto":
@@ -175,7 +285,7 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
     elif os.path.exists(os.path.join(args.db_path, "faiss_store.db")):
         args.vector_store_type = "llamastack-faiss"
     else:
-        print("Cannot recognize the DB in", args.db_path)
+        logging.error(f"Cannot recognize the DB in {args.db_path}")
         exit(1)
 
     if args.vector_store_type == "faiss":
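
A minimal usage sketch, not part of the diff above, of how the new `--json` output might be consumed downstream. Only `--json` is confirmed by this change; `--query`, `--db-path`, and `--top-k` are assumed spellings of the argparse destinations visible in the diff, and the query and path values are placeholders:

# consume_rag_json.py: a sketch under the assumptions stated above.
import json
import subprocess
import sys

proc = subprocess.run(
    [
        sys.executable, "scripts/query_rag.py",
        "--query", "example question",  # placeholder query
        "--db-path", "vector_db",       # assumed flag for args.db_path
        "--top-k", "3",                 # assumed flag for args.top_k
        "--json",
    ],
    capture_output=True,
    text=True,
)

# In --json mode the script configures logging at ERROR level on stderr,
# so stdout should carry only the JSON document, even when "nodes" is empty.
if proc.stdout.strip():
    result = json.loads(proc.stdout)
    if proc.returncode != 0:
        # The script exits 1 when nothing is retrieved or the top score
        # misses the threshold; "nodes" is then an empty list.
        print(f"no results for {result['query']!r}", file=sys.stderr)
    else:
        for node in result["nodes"]:
            print(node["id"], node["score"])
else:
    # e.g. an unrecognized DB path: the script logs to stderr and exits 1
    # before any JSON is written.
    print(proc.stderr, file=sys.stderr)

Keeping stdout reserved for the JSON payload is what makes shell piping such as `python scripts/query_rag.py ... --json | jq '.nodes[].score'` safe.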