diff --git a/scripts/query_rag.py b/scripts/query_rag.py index 920242a..d2038eb 100644 --- a/scripts/query_rag.py +++ b/scripts/query_rag.py @@ -2,6 +2,8 @@ import argparse import importlib +import json +import logging import os import sys import tempfile @@ -30,23 +32,76 @@ def _llama_index_query(args: argparse.Namespace) -> None: storage_context=storage_context, index_id=args.product_index, ) + if args.node is not None: - print(storage_context.docstore.get_node(args.node)) + node = storage_context.docstore.get_node(args.node) + result = { + "query": args.query, + "type": "single_node", + "node_id": args.node, + "node": { + "id": node.node_id, + "text": node.text, + "metadata": node.metadata if hasattr(node, 'metadata') else {} + } + } + if args.json: + print(json.dumps(result, indent=2)) + else: + print(node) else: retriever = vector_index.as_retriever(similarity_top_k=args.top_k) nodes = retriever.retrieve(args.query) + if len(nodes) == 0: - print(f"No nodes retrieved for query: {args.query}") + logging.warning(f"No nodes retrieved for query: {args.query}") + if args.json: + result = { + "query": args.query, + "top_k": args.top_k, + "threshold": args.threshold, + "nodes": [] + } + print(json.dumps(result, indent=2)) exit(1) + if args.threshold > 0.0 and nodes[0].score < args.threshold: - print( + logging.warning( f"Score {nodes[0].score} of the top retrieved node for query '{args.query}' " f"didn't cross the minimal threshold {args.threshold}." ) + if args.json: + result = { + "query": args.query, + "top_k": args.top_k, + "threshold": args.threshold, + "nodes": [] + } + print(json.dumps(result, indent=2)) exit(1) - for n in nodes: - print("=" * 80) - print(n) + + # Format results + result = { + "query": args.query, + "top_k": args.top_k, + "threshold": args.threshold, + "nodes": [] + } + for node in nodes: + node_data = { + "id": node.node_id, + "score": node.score, + "text": node.text, + "metadata": node.metadata if hasattr(node, 'metadata') else {} + } + result["nodes"].append(node_data) + + if args.json: + print(json.dumps(result, indent=2)) + else: + for n in nodes: + print("=" * 80) + print(n) def _get_db_path_dict(vector_type: str, config: dict[str, Any]) -> dict[str, Any]: @@ -108,20 +163,56 @@ def _llama_stack_query(args: argparse.Namespace) -> None: md = res.metadata if len(md["chunks"]) == 0: - print(f"No chunks retrieved for query: {args.query}") + logging.warning(f"No chunks retrieved for query: {args.query}") + if args.json: + result = { + "query": args.query, + "top_k": args.top_k, + "threshold": args.threshold, + "nodes": [] + } + print(json.dumps(result, indent=2)) exit(1) + threshold = args.threshold if threshold > 0.0 and md.get("scores") and md["scores"][0].score < threshold: - print( + logging.warning( f"Score {md['scores'][0].score} of the top retrieved node for query '{args.query}' " f"didn't cross the minimal threshold {threshold}." ) + if args.json: + result = { + "query": args.query, + "top_k": args.top_k, + "threshold": args.threshold, + "nodes": [] + } + print(json.dumps(result, indent=2)) exit(1) - # Method 1 to present data: + # Format results + result = { + "query": args.query, + "top_k": args.top_k, + "threshold": args.threshold, + "nodes": [] + } + for _id, chunk, score in zip(md["document_ids"], md["chunks"], md["scores"]): - print("=" * 80) - print(f"Node ID: {_id}\nScore: {score}\nText:\n{chunk}") + node_data = { + "id": _id, + "score": score.score if hasattr(score, 'score') else score, + "text": chunk, + "metadata": {} + } + result["nodes"].append(node_data) + + if args.json: + print(json.dumps(result, indent=2)) + else: + for _id, chunk, score in zip(md["document_ids"], md["chunks"], md["scores"]): + print("=" * 80) + print(f"Node ID: {_id}\nScore: {score}\nText:\n{chunk}") # Method 2 to present data: # for content in res.content: @@ -130,7 +221,6 @@ def _llama_stack_query(args: argparse.Namespace) -> None: # else: # print(content) - if __name__ == "__main__": parser = argparse.ArgumentParser( description="Utility script for querying RAG database" @@ -161,10 +251,30 @@ def _llama_stack_query(args: argparse.Namespace) -> None: choices=["auto", "faiss", "llamastack-faiss", "llamastack-sqlite-vec"], help="vector store type to be used.", ) + parser.add_argument( + "--json", + action="store_true", + help="Output results in JSON format", + ) args = parser.parse_args() - print("Command line used: " + " ".join(sys.argv)) + if args.json: + # In JSON mode, only show ERROR or higher to avoid polluting JSON output + logging.basicConfig( + level=logging.ERROR, + format='%(levelname)s: %(message)s', + stream=sys.stderr # Send logs to stderr to keep stdout clean for JSON + ) + else: + # In normal mode, show info and above + logging.basicConfig( + level=logging.INFO, + format='%(message)s' + ) + + if not args.json: + logging.info("Command line used: " + " ".join(sys.argv)) vector_store_type = args.vector_store_type if args.vector_store_type == "auto": @@ -175,7 +285,7 @@ def _llama_stack_query(args: argparse.Namespace) -> None: elif os.path.exists(os.path.join(args.db_path, "faiss_store.db")): args.vector_store_type = "llamastack-faiss" else: - print("Cannot recognize the DB in", args.db_path) + logging.error(f"Cannot recognize the DB in {args.db_path}") exit(1) if args.vector_store_type == "faiss":