In [1]:
# Naive RAG System
# This notebook implements a basic RAG system using the pre-built database from data_setup.ipynb

# Load all required Libraries
import pandas as pd
import transformers, torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
import gc
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
from tqdm import tqdm
import numpy as np
from sentence_transformers import SentenceTransformer


  from .autonotebook import tqdm as notebook_tqdm


# Load Pre-built Database and Data

**Prerequisites**: Run `data_setup.ipynb` first to create the database and embeddings.


In [3]:
# Load pre-built data and database
print("Loading pre-built database and data...")

# Load queries from saved CSV
queries = pd.read_csv("../data/processed/queries.csv")
print(f"Loaded {len(queries)} queries")

# Load pre-computed query embeddings
query_embeddings = np.load("../data/processed/query_embeddings.npy")
print(f"Loaded query embeddings: {query_embeddings.shape}")

# Initialize embedding model (for consistency)
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
print("Embedding model loaded")

# Check if database file exists
import os
db_path = "../data/processed/rag_wikipedia_mini.db"

print(f"✅ Database file found at {db_path}")

client = MilvusClient(db_path)
print("Connected to existing Milvus database")

Loading pre-built database and data...
Loaded 918 queries
Loaded query embeddings: (918, 384)
Embedding model loaded
✅ Database file found at ../data/processed/rag_wikipedia_mini.db
Connected to existing Milvus database


In [4]:
# Define search function for vector database
def search_and_fetch_top_n_passages(query_emb, limit=3):
    """
    Search for similar passages in the vector database
    
    Args:
        query_emb: Query embedding vector
        limit: Number of top results to return
    
    Returns:
        Search results from Milvus
    """
    search_params = {
        "metric_type": "COSINE",
        "params": {"nprobe": 10}
    }
    
    output_ = client.search(
        collection_name="rag_mini",
        data=[query_emb.tolist()],
        anns_field="embedding",
        search_params=search_params,
        limit=limit,
        output_fields=["passage"]
    )
    return output_

print("Search function defined for vector database queries")


Search function defined for vector database queries


# RAG Response for a Single Query


In [5]:
# Load the LLM Model you want to use
# Using a smaller model
model_name = "google/flan-t5-base"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        dtype=torch.float32
    )
    
    print(f"Loaded model: {model_name}")
    print(f"Model parameters: {model.num_parameters():,}")
    
except Exception as e:
    print(f"Failed to load model: {e}")
    raise


Loaded model: google/flan-t5-base
Model parameters: 247,577,856


In [6]:
# Test with a single query
query = queries['question'].iloc[0]  # First query
query_embedding = query_embeddings[0]

print(f"Test query: {query}")

# Search for similar passages
search_results = search_and_fetch_top_n_passages(query_embedding, 3)

# Extract context from search results
top_3_passages = []
for i in range(min(3, len(search_results[0]))):
    top_3_passages.append(search_results[0][i]['entity']['passage'])

context = "\n\n".join(top_3_passages)
print(f"\nContext (top 3 passages):")
for i, passage in enumerate(top_3_passages):
    print(f"\nPassage {i+1}: {passage[:100]}...")

# Create prompt
system_prompt = """You are a helpful assistant that answers questions based on the provided context. 
Use only the information from the context to answer the question. 
If the context doesn't contain enough information to answer the question, say "not enough information".
Be concise and accurate in your response."""

prompt = f"""{system_prompt}\n
Context: {context}\n
Question: {query}"""

print(f"\nPrompt: {prompt[:200]}...")


Test query: Was Abraham Lincoln the sixteenth President of the United States?

Context (top 3 passages):

Passage 1: Young Abraham Lincoln...

Passage 2: Abraham Lincoln (February 12, 1809 â April 15, 1865) was the sixteenth President of the United Sta...

Passage 3: Sixteen months before his death, his son, John Quincy Adams, became the sixth President of the Unite...

Prompt: You are a helpful assistant that answers questions based on the provided context. 
Use only the information from the context to answer the question. 
If the context doesn't contain enough information ...


In [7]:
# Generate answer with proper memory management
try:
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=150,
            num_beams=4,
            early_stopping=True,
            do_sample=False
        )
    
    # Decode and extract answer
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Query: {query}")
    print(f"Generated Answer: {answer}")
    
    # Clear memory after generation
    del inputs, outputs
    gc.collect()
        
except Exception as e:
    print(f"Answer generation failed: {e}")
    raise


Query: Was Abraham Lincoln the sixteenth President of the United States?
Generated Answer: yes.


In [9]:
# Multiprocessing RAG Functions with Resume Capability
def process_single_query(args):
    """
    Process a single query through the RAG pipeline
    
    Args:
        args: tuple containing (query_idx, question, embedding, client, tokenizer, model, system_prompt, n, existing_answers)
    
    Returns:
        tuple: (query_idx, generated_answer, combined_context, success, skipped)
    """
    query_idx, question, embedding, client, tokenizer, model, system_prompt, n, existing_answers = args
    
    # Check if query already has an answer
    if query_idx < len(existing_answers) and existing_answers[query_idx] and existing_answers[query_idx] != "":
        return (query_idx, existing_answers[query_idx], "", True, True)  # skipped=True
    
    try:
        # Search for similar passages
        search_results = search_and_fetch_top_n_passages(embedding, n)
        
        # Extract top n passages as context
        top_n_passages = []
        for i in range(min(n, len(search_results[0]))):
            top_n_passages.append(search_results[0][i]['entity']['passage'])
        
        # Combine all contexts into a single string
        combined_context = "\n\n".join(top_n_passages)
        
        # Create prompt
        prompt = f"""{system_prompt}\n
        Context: {combined_context}\n
        Question: {question}"""
        
        # Generate answer
        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
        
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_length=150,
                num_beams=4,
                early_stopping=True,
                do_sample=False
            )
        
        # Decode the generated answer
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Clear memory
        del inputs, outputs
        gc.collect()
        
        return (query_idx, answer, combined_context, True, False)  # skipped=False
        
    except Exception as e:
        print(f"Error processing query {query_idx + 1}: {e}")
        return (query_idx, "Error generating answer", "", False, False)

def process_queries_parallel(queries, query_embeddings, client, tokenizer, model, system_prompt, n=3, max_workers=4, save_interval=50):
    """
    Process all queries in parallel using ThreadPoolExecutor with resume capability
    
    Args:
        queries: DataFrame with questions
        query_embeddings: numpy array of query embeddings
        client: Milvus client
        tokenizer: Model tokenizer
        model: Language model
        system_prompt: System prompt string
        n: Number of top contexts to retrieve
        max_workers: Number of parallel workers
        save_interval: Save progress every N queries
    
    Returns:
        tuple: (generated_answers, combined_contexts, success_count, skipped_count)
    """
    print(f"Starting parallel processing with {max_workers} workers...")
    print(f"Processing {len(queries)} queries...")
    
    # Check for existing results
    existing_answers = []
    existing_contexts = []
    
    if 'generated_answer' in queries.columns:
        existing_answers = queries['generated_answer'].fillna("").tolist()
        print(f"Found {sum(1 for ans in existing_answers if ans and ans != '')} existing answers")
    
    if 'combined_context' in queries.columns:
        existing_contexts = queries['combined_context'].fillna("").tolist()
    
    # Prepare arguments for each query
    query_args = []
    for idx, (question, embedding) in enumerate(zip(queries['question'], query_embeddings)):
        query_args.append((idx, question, embedding, client, tokenizer, model, system_prompt, n, existing_answers))
    
    # Initialize results storage
    generated_answers = existing_answers[:] if existing_answers else [""] * len(queries)
    combined_contexts = existing_contexts[:] if existing_contexts else [""] * len(queries)
    success_count = 0
    skipped_count = 0
    
    # Process queries in parallel
    start_time = time.time()
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_idx = {executor.submit(process_single_query, args): args[0] for args in query_args}
        
        # Process completed tasks with progress bar
        with tqdm(total=len(queries), desc="Processing queries") as pbar:
            for i, future in enumerate(as_completed(future_to_idx)):
                query_idx, answer, context, success, skipped = future.result()
                
                # Store results
                generated_answers[query_idx] = answer
                if context:  # Only update context if we have a new one
                    combined_contexts[query_idx] = context
                
                if success:
                    success_count += 1
                if skipped:
                    skipped_count += 1
                
                pbar.update(1)
                
                # Periodic saving to prevent data loss
                if (i + 1) % save_interval == 0:
                    temp_df = queries.copy()
                    temp_df['generated_answer'] = generated_answers
                    temp_df['combined_context'] = combined_contexts
                    temp_df.to_csv("../results/rag_generated_answers_temp.csv", index=False)
                    print(f"\n💾 Progress saved at query {i + 1}")
    
    end_time = time.time()
    processing_time = end_time - start_time
    
    print(f"\nParallel processing completed!")
    print(f"Total time: {processing_time:.2f} seconds")
    print(f"Average time per query: {processing_time/len(queries):.2f} seconds")
    print(f"Successful queries: {success_count}/{len(queries)}")
    print(f"Skipped queries: {skipped_count}/{len(queries)}")
    print(f"Speed improvement: ~{max_workers}x faster than sequential processing")
    
    return generated_answers, combined_contexts, success_count, skipped_count


# Generate Responses for all the Queries in the Dataset


In [10]:
# Run process_queries_parallel

generated_answers, combined_contexts, success_count, skipped_count = process_queries_parallel(
    queries,
    query_embeddings,
    client,
    tokenizer,
    model,
    system_prompt,
    n=3,
    max_workers=4,
    save_interval=50
)

I0000 00:00:1759531997.669011 2883788 fork_posix.cc:71] Other threads are currently calling into gRPC, skipping fork() handlers


Starting parallel processing with 4 workers...
Processing 918 queries...


Processing queries:   0%|          | 0/918 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Processing queries:   5%|▌         | 50/918 [01:56<08:37,  1.68it/s]  


💾 Progress saved at query 50


Processing queries:  11%|█         | 100/918 [03:24<17:42,  1.30s/it]


💾 Progress saved at query 100


Processing queries:  16%|█▋        | 150/918 [05:12<20:33,  1.61s/it]  


💾 Progress saved at query 150


Processing queries:  22%|██▏       | 200/918 [07:06<24:02,  2.01s/it]


💾 Progress saved at query 200


Processing queries:  27%|██▋       | 250/918 [08:26<15:51,  1.42s/it]


💾 Progress saved at query 250


Processing queries:  33%|███▎      | 300/918 [09:49<14:53,  1.45s/it]


💾 Progress saved at query 300


Processing queries:  38%|███▊      | 350/918 [11:05<16:28,  1.74s/it]


💾 Progress saved at query 350


Processing queries:  44%|████▎     | 400/918 [13:01<08:37,  1.00it/s]


💾 Progress saved at query 400


Processing queries:  49%|████▉     | 450/918 [14:27<05:58,  1.31it/s]


💾 Progress saved at query 450


Processing queries:  54%|█████▍    | 500/918 [16:05<10:01,  1.44s/it]


💾 Progress saved at query 500


Processing queries:  60%|█████▉    | 550/918 [17:21<05:00,  1.23it/s]


💾 Progress saved at query 550


Processing queries:  65%|██████▌   | 600/918 [18:18<03:04,  1.72it/s]


💾 Progress saved at query 600


Processing queries:  71%|███████   | 650/918 [20:30<10:34,  2.37s/it]


💾 Progress saved at query 650


Processing queries:  76%|███████▋  | 700/918 [22:38<17:33,  4.83s/it]


💾 Progress saved at query 700


Processing queries:  82%|████████▏ | 750/918 [25:06<06:46,  2.42s/it]


💾 Progress saved at query 750


Processing queries:  87%|████████▋ | 801/918 [27:33<07:03,  3.62s/it]


💾 Progress saved at query 800


Processing queries:  93%|█████████▎| 850/918 [30:26<05:49,  5.13s/it]


💾 Progress saved at query 850


Processing queries:  98%|█████████▊| 900/918 [32:38<01:29,  4.97s/it]


💾 Progress saved at query 900


Processing queries: 100%|██████████| 918/918 [33:11<00:00,  2.17s/it]


Parallel processing completed!
Total time: 1991.39 seconds
Average time per query: 2.17 seconds
Successful queries: 918/918
Skipped queries: 0/918
Speed improvement: ~4x faster than sequential processing





In [11]:
# Save final results
queries['generated_answer'] = generated_answers
queries['combined_context'] = combined_contexts
queries.to_csv("../results/naive_rag_answers.csv", index=False)
print(f"\n✅ Final results saved to ../results/naive_rag_answers.csv")


✅ Final results saved to ../results/naive_rag_answers.csv


In [12]:
# Remove temporary file if exists
temp_file_path = "../results/rag_generated_answers_temp.csv"
if os.path.exists(temp_file_path):
    os.remove(temp_file_path)
    print(f"🗑️ Removed temporary file: {temp_file_path}")

🗑️ Removed temporary file: ../results/rag_generated_answers_temp.csv
