# RAG testset generation

**Purpose:** Generate test questions for RAG Evaluation

---
**Copyright (c) 2025 Michael Powers**

# Imports & Config

In [19]:
import os
import chromadb
from llama_index.core import Document
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
import logging
import json
from datetime import datetime
from typing import List, Dict, Any
import google.generativeai as genai
import time

# For Hugging Face Dataset format
from datasets import Dataset

In [2]:
gemini_model = 'gemini-2.5-flash-lite-preview-06-17'
api_key="YOUR_API_KEY"

In [3]:
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [4]:
# --- Configuration ---
CHROMA_DB_PATH = "../application/chroma_db"
SCHEMA_COLLECTION_NAME = "sql_schema_metadata_collection"
BUSINESS_TERMS_COLLECTION_NAME = "business_terms_collection"
HUGGINGFACE_EMBEDDING_MODEL_NAME = "BAAI/bge-small-en-v1.5" # needs to match


# LLM Caller Functions

In [5]:
def ask_gemini_json(prompt, use_json=True, model='models/gemini-2.0-flash-lite'):
    import os
    import google.generativeai as genai
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(model)
    if use_json:
        generation_config = genai.GenerationConfig(response_mime_type="application/json")
        response = model.generate_content(prompt, generation_config=generation_config)
    else:
        response = model.generate_content(prompt)
    return response.text

In [12]:
def get_prompt(context_text):
    prompt = f"""
Based on the following context, generate ONE question about the content and its corresponding ground truth answer.
The question should be concise and directly answerable from the provided context.
The ground truth answer should be derived directly and accurately from the context.
Focus on questions that someone building a SQL query or understanding business terms would ask.

Output the question and answer in a JSON format with keys "question" and "ground_truth_answer".
Ensure the "ground_truth_answer" is a string, even if it contains a list or structured information.

Context:

{context_text}

Example JSON format:
```json
{{
  "question": "What is the primary key of the 'Customers' table?",
  "ground_truth_answer": "The primary key of the 'Customers' table is 'customer_id'."
}}
```
"""
    return prompt

# Get RAG Documents

In [7]:
def get_embedding_model():
    return HuggingFaceEmbedding(model_name=HUGGINGFACE_EMBEDDING_MODEL_NAME)

In [8]:
def get_all_nodes_from_chroma(collection_name: str) -> List[Document]:
    logger.info(f"Retrieving all nodes from collection: {collection_name} for testset generation...")
    try:
        db = chromadb.PersistentClient(path=CHROMA_DB_PATH)
        chroma_collection = db.get_or_create_collection(collection_name)

        all_ids = chroma_collection.get(ids=chroma_collection.get()['ids'])['ids']
        results = chroma_collection.get(ids=all_ids, include=['documents', 'metadatas'])

        documents = []
        for i in range(len(results['ids'])):
            raw_text_from_chroma = results['documents'][i]
            metadata = results['metadatas'][i]
            doc = Document(text=raw_text_from_chroma, metadata=metadata, id_=results['ids'][i])
            documents.append(doc)
        logger.info(f"Retrieved {len(documents)} nodes from {collection_name}.")
        return documents
    except Exception as e:
        logger.error(f"Error retrieving nodes from ChromaDB collection {collection_name}: {e}")
        return []


# question gen

In [29]:
def generate_questions(documents, test_set_size = 15, model=gemini_model, rpm_limit=15):
    import random

    RPM_LIMIT = rpm_limit
    MAX_RETRIES = 5
    BASE_SLEEP_TIME = 4.5
    
    generated_data_for_ragas = []
    num_generated_questions = 0

    if test_set_size < len(documents):
        documents_to_process = random.sample(documents, test_set_size)
    else:
        documents_to_process = documents

    # Counter for API calls made within the current minute
    requests_in_minute = 0
    start_time_minute = time.time()
    
    
    for i, doc in enumerate(documents_to_process):
        if num_generated_questions >= test_set_size:
            break
        context_text = doc.get_content()
        # Ensure context is not empty
        if not context_text.strip():
            logger.warning(f"Skipping empty context for document {doc.id_}")
            continue
        logger.info(f"Generating question for document {i+1}/{len(documents_to_process)}...")
        
        retries = 0
        while retries < MAX_RETRIES:
            if num_generated_questions >= test_set_size:
                break

            # Check RPM limit
            current_time = time.time()
            if current_time - start_time_minute >= 60:
                requests_in_minute = 0
                start_time_minute = current_time

            if requests_in_minute >= RPM_LIMIT:
                wait_time = 60 - (current_time - start_time_minute)
                print(f"Rate limit hit. Waiting for {wait_time:.2f} seconds...")
                time.sleep(wait_time + 1)
                requests_in_minute = 0
                start_time_minute = time.time()
                
            try: # RESPONSE
                prompt = get_prompt(context_text)
                response = ask_gemini_json(prompt, use_json=True, model=model)
                response = response.strip()
                if response.startswith("```json") and response.endswith("```"):
                    response = response[len("```json"): -len("```")].strip()
                elif response.startswith("```") and response.endswith("```"):
                    response = response[len("```"): -len("```")].strip()
                generated_text = response
                try: # PARSE JSON
                    qa_pair = json.loads(generated_text)
                    question = qa_pair.get("question")
                    ground_truth_answer = qa_pair.get("ground_truth_answer")
                
                    if question and ground_truth_answer:
                        # Append the original document's text as contexts
                        # Ragas expects a list of strings for 'contexts'
                        generated_data_for_ragas.append({
                            "question": question,
                            "ground_truth_answers": [ground_truth_answer], # Ragas expects a list
                            "contexts": [context_text] # Use the full context that the Q&A was derived from
                        })
                        num_generated_questions += 1
                        logger.info(f"Generated Q&A pair {num_generated_questions}/{test_set_size}")
                    else:
                        logger.warning(f"LLM generated incomplete JSON for document {doc.id_}: {generated_text}")
                except json.JSONDecodeError:
                    logger.warning(f"LLM did not generate valid JSON for document {doc.id_}. Response: {generated_text[:500]}...")
                break # Success, break out of retry loop
            except Exception as e:     
                retries += 1
                sleep_duration = BASE_SLEEP_TIME * (2 ** (retries - 1)) + random.uniform(0, 1)
                logger.warning(f"API Error : {e}. Retrying in {sleep_duration:.2f}s... (Attempt {retries}/{MAX_RETRIES})")
                time.sleep(sleep_duration)

            if retries == MAX_RETRIES:
                logger.error(f"Failed to process after {MAX_RETRIES} retries. Skipping.")
                continue # Continue to next document even if one fails
    #### DONE LOOPING
    
    if not generated_data_for_ragas:
        logger.error("No questions were successfully generated. Please review logs and configuration.")
        exit(1)

    logger.info(f"Successfully generated {len(generated_data_for_ragas)} Q&A pairs.")
    hf_dataset = Dataset.from_list(generated_data_for_ragas)

    ### SAVE TO JSON FILE ###
    output_filename = f"ragas_custom_testset_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    hf_dataset.to_json(output_filename, indent=4)
    logger.info(f"Generated RAG evaluation testset saved to: {output_filename}")

    print(f"\nSuccessfully generated a RAG evaluation testset with {len(hf_dataset)} questions.")
    print("Here's a sample of the generated questions and contexts:")
    for i, entry in enumerate(hf_dataset):
        if i >= 3: # Print first 3 samples
            break
        print(f"\n--- Sample {i+1} ---")
        print(f"Question: {entry['question']}")
        print(f"Ground Truth Answer: {entry['ground_truth_answers'][0]}") 
        print(f"Contexts ({len(entry['contexts'])} nodes):")
        for j, context_text in enumerate(entry['contexts']):
            if j < 1: # Print only the first context for brevity
                print(f"  - Context {j+1}: {context_text[:300]}...") # Truncate for display
            else:
                print(f"  - ... ({len(entry['contexts']) - 1} more contexts)")
                break

# Prep and run

In [27]:
#create embedding
embeddings_model = get_embedding_model()

#retrieve and combine documents
schema_documents = get_all_nodes_from_chroma(SCHEMA_COLLECTION_NAME)
business_terms_documents = get_all_nodes_from_chroma(BUSINESS_TERMS_COLLECTION_NAME)
all_retrieved_documents = schema_documents + business_terms_documents

if not all_retrieved_documents:
    logger.error("No source documents found in ChromaDB collections. Cannot generate test set. Make sure ingest.py ran successfully.")
    exit(1)

2025-07-20 08:21:42,828 - INFO - Load pretrained SentenceTransformer: BAAI/bge-small-en-v1.5
2025-07-20 08:21:45,418 - INFO - 2 prompts are loaded, with the keys: ['query', 'text']
2025-07-20 08:21:45,432 - INFO - Retrieving all nodes from collection: sql_schema_metadata_collection for testset generation...
2025-07-20 08:21:45,651 - INFO - Retrieved 134 nodes from sql_schema_metadata_collection.
2025-07-20 08:21:45,652 - INFO - Retrieving all nodes from collection: business_terms_collection for testset generation...
2025-07-20 08:21:45,813 - INFO - Retrieved 1 nodes from business_terms_collection.


In [30]:
generate_questions(all_retrieved_documents, 50)

2025-07-20 08:28:22,702 - INFO - Generating question for document 1/50...
2025-07-20 08:28:23,539 - INFO - Generated Q&A pair 1/50
2025-07-20 08:28:23,539 - INFO - Generating question for document 2/50...
2025-07-20 08:28:24,718 - INFO - Generated Q&A pair 2/50
2025-07-20 08:28:24,719 - INFO - Generating question for document 3/50...
2025-07-20 08:28:25,633 - INFO - Generated Q&A pair 3/50
2025-07-20 08:28:25,634 - INFO - Generating question for document 4/50...
2025-07-20 08:28:26,332 - INFO - Generated Q&A pair 4/50
2025-07-20 08:28:26,333 - INFO - Generating question for document 5/50...
2025-07-20 08:28:27,159 - INFO - Generated Q&A pair 5/50
2025-07-20 08:28:27,160 - INFO - Generating question for document 6/50...
2025-07-20 08:28:28,896 - INFO - Generated Q&A pair 6/50
2025-07-20 08:28:28,897 - INFO - Generating question for document 7/50...
2025-07-20 08:28:31,862 - INFO - Generated Q&A pair 7/50
2025-07-20 08:28:31,863 - INFO - Generating question for document 8/50...
2025-07-2

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

2025-07-20 08:30:40,577 - INFO - Generated RAG evaluation testset saved to: ragas_custom_testset_20250720_083040.json



Successfully generated a RAG evaluation testset with 50 questions.
Here's a sample of the generated questions and contexts:

--- Sample 1 ---
Question: What does the 'PID' column in the 'M_Director' table represent?
Ground Truth Answer: The 'PID' column in the 'M_Director' table represents the person ID of the director.
Contexts (1 nodes):
  - Context 1: Database: DB_IMDB
Table: M_Director
Columns:
  index (INTEGER)
  MID (TEXT)
  PID (TEXT)
  ID (INTEGER)
Sample Rows:
  {'index': 1046, 'MID': 'tt6080746', 'PID': 'nm0223606', 'ID': 1046}
  {'index': 2699, 'MID': 'tt2962230', 'PID': 'nm0154269', 'ID': 2699}...

--- Sample 2 ---
Question: What is the name of the column that stores the quantity of items sold in kilograms?
Ground Truth Answer: The column that stores the quantity of items sold in kilograms is 'qty_sold(kg)'.
Contexts (1 nodes):
  - Context 1: Database: bank_sales_trading
Table: veg_txn_df
Columns:
  index (INTEGER)
  txn_date (TEXT)
  txn_time (TEXT)
  item_code (INTEGER)
