In [1]:
# Enhanced RAG System with Reranking and Query Rewriting
import pandas as pd
import transformers, torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
import gc
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
from tqdm import tqdm
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
import re
from typing import List, Dict, Tuple

  from .autonotebook import tqdm as notebook_tqdm


# Enhanced RAG System

This notebook implements an enhanced RAG system with two key improvements:

1. **Query Rewriting**: Expands and improves queries for better retrieval
2. **Reranking**: Uses cross-encoder models to rerank retrieved passages for better relevance

## Key Features:
- Query expansion and rewriting
- Cross-encoder reranking for better passage selection
- Multiprocessing for faster processing
- Resume capability for interrupted runs


In [2]:
# Load Pre-built Database and Data

# **Prerequisites**: Run `data_setup.ipynb` first to create the database and embeddings.

# Load pre-built data and database
print("Loading pre-built database and data for Enhanced RAG...")

# Load queries from saved CSV
queries = pd.read_csv("../data/processed/queries.csv")
print(f"Loaded {len(queries)} queries")

# Load pre-computed query embeddings
query_embeddings = np.load("../data/processed/query_embeddings.npy")
print(f"Loaded query embeddings: {query_embeddings.shape}")

# Initialize embedding model (for consistency)
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
print("Embedding model loaded")

# Connect to existing Milvus database
client = MilvusClient("../data/processed/rag_wikipedia_mini.db")
print("Connected to existing Milvus database")

Loading pre-built database and data for Enhanced RAG...
Loaded 918 queries
Loaded query embeddings: (918, 384)
Embedding model loaded


  from pkg_resources import DistributionNotFound, get_distribution


Connected to existing Milvus database


In [3]:
# Initialize Enhanced RAG Components

# 1. Query Rewriting Model (T5-based for query expansion)
query_rewriter_model = "google/flan-t5-base"
query_rewriter_tokenizer = AutoTokenizer.from_pretrained(query_rewriter_model)
query_rewriter = AutoModelForSeq2SeqLM.from_pretrained(query_rewriter_model, dtype=torch.float32)

# 2. Reranking Model (Cross-encoder for passage reranking)
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 3. Generation Model (for final answer generation)
generation_model_name = "google/flan-t5-base"
generation_tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
generation_model = AutoModelForSeq2SeqLM.from_pretrained(generation_model_name, dtype=torch.float32)

print("Enhanced RAG components loaded:")
print(f"- Query rewriter: {query_rewriter_model}")
print(f"- Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2")
print(f"- Generation model: {generation_model_name}")


Enhanced RAG components loaded:
- Query rewriter: google/flan-t5-base
- Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
- Generation model: google/flan-t5-base


In [4]:
# Enhanced RAG Functions

def rewrite_query(original_query: str, tokenizer, model) -> str:
    """
    Rewrite and expand the original query for better retrieval
    
    Args:
        original_query: The original user query
        tokenizer: Query rewriter tokenizer
        model: Query rewriter model
    
    Returns:
        str: Rewritten/expanded query
    """
    # Create a prompt for query expansion
    expansion_prompt = f"""Rewrite and expand this question to make it more specific and searchable. 
    Include relevant keywords and context that would help find better information.
    
    Original question: {original_query}
    
    Expanded question:"""
    
    try:
        inputs = tokenizer(expansion_prompt, return_tensors="pt", max_length=512, truncation=True)
        
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_length=100,
                num_beams=3,
                early_stopping=True,
                do_sample=False
            )
        
        expanded_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Clean up the output
        expanded_query = expanded_query.strip()
        
        # If expansion failed or is too short, return original
        if len(expanded_query) < len(original_query) * 0.5:
            return original_query
            
        return expanded_query
        
    except Exception as e:
        print(f"Query rewriting failed: {e}")
        return original_query

def rerank_passages(query: str, passages: List[str], reranker_model, top_k: int = 3) -> List[Tuple[str, float]]:
    """
    Rerank retrieved passages using a cross-encoder model
    
    Args:
        query: The search query
        passages: List of retrieved passages
        reranker_model: Cross-encoder reranking model
        top_k: Number of top passages to return
    
    Returns:
        List of tuples: (passage, relevance_score)
    """
    if not passages:
        return []
    
    try:
        # Create query-passage pairs for reranking
        pairs = [(query, passage) for passage in passages]
        
        # Get relevance scores from cross-encoder
        scores = reranker_model.predict(pairs)
        
        # Combine passages with scores and sort by relevance
        passage_scores = list(zip(passages, scores))
        passage_scores.sort(key=lambda x: x[1], reverse=True)
        
        # Return top-k passages
        return passage_scores[:top_k]
        
    except Exception as e:
        print(f"Reranking failed: {e}")
        # Return original passages if reranking fails
        return [(passage, 0.0) for passage in passages[:top_k]]

def search_and_fetch_top_n_passages(query_emb, limit=10):
    """
    Search for similar passages in the vector database (same as naive RAG)
    """
    search_params = {
        "metric_type": "COSINE",
        "params": {"nprobe": 10}
    }
    
    output_ = client.search(
        collection_name="rag_mini",
        data=[query_emb.tolist()],
        anns_field="embedding",
        search_params=search_params,
        limit=limit,
        output_fields=["passage"]
    )
    return output_

print("Enhanced RAG functions defined:")
print("- rewrite_query(): Expands queries for better retrieval")
print("- rerank_passages(): Reranks passages using cross-encoder")
print("- search_and_fetch_top_n_passages(): Vector search function")


Enhanced RAG functions defined:
- rewrite_query(): Expands queries for better retrieval
- rerank_passages(): Reranks passages using cross-encoder
- search_and_fetch_top_n_passages(): Vector search function


In [5]:
# Enhanced RAG Pipeline with Multiprocessing

def process_single_enhanced_query(args):
    """
    Process a single query through the enhanced RAG pipeline
    
    Args:
        args: tuple containing (query_idx, question, embedding, client, embedding_model, 
              query_rewriter_tokenizer, query_rewriter, reranker, generation_tokenizer, 
              generation_model, system_prompt, n)
    
    Returns:
        tuple: (query_idx, generated_answer, combined_context, success, skipped)
    """
    (query_idx, question, embedding, client, embedding_model, 
     query_rewriter_tokenizer, query_rewriter, reranker, 
     generation_tokenizer, generation_model, system_prompt, n, existing_answers) = args
    
    # Check if query already has an answer
    if query_idx < len(existing_answers) and existing_answers[query_idx] and existing_answers[query_idx] != "":
        return (query_idx, existing_answers[query_idx], "", True, True)  # skipped=True
    
    try:
        # Step 1: Query Rewriting
        rewritten_query = rewrite_query(question, query_rewriter_tokenizer, query_rewriter)
        
        # Step 2: Generate embedding for rewritten query
        rewritten_embedding = embedding_model.encode([rewritten_query])[0]
        
        # Step 3: Retrieve more passages (we'll rerank them)
        search_results = search_and_fetch_top_n_passages(rewritten_embedding, limit=10)
        
        # Extract passages
        retrieved_passages = []
        for i in range(len(search_results[0])):
            retrieved_passages.append(search_results[0][i]['entity']['passage'])
        
        # Step 4: Rerank passages using cross-encoder
        reranked_passages = rerank_passages(rewritten_query, retrieved_passages, reranker, top_k=n)
        
        # Extract top reranked passages
        top_passages = [passage for passage, score in reranked_passages]
        combined_context = "\n\n".join(top_passages)
        
        # Step 5: Generate answer with enhanced context
        prompt = f"""{system_prompt}\n
        Context: {combined_context}\n
        Question: {question}"""
        
        inputs = generation_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
        
        with torch.no_grad():
            outputs = generation_model.generate(
                inputs.input_ids,
                max_length=150,
                num_beams=4,
                early_stopping=True,
                do_sample=False
            )
        
        # Decode the generated answer
        answer = generation_tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Clear memory
        del inputs, outputs
        gc.collect()
        
        return (query_idx, answer, combined_context, True, False)  # skipped=False
        
    except Exception as e:
        print(f"Error processing enhanced query {query_idx + 1}: {e}")
        return (query_idx, "Error generating answer", "", False, False)

def process_enhanced_queries_parallel(queries, query_embeddings, client, embedding_model, 
                                    query_rewriter_tokenizer, query_rewriter, reranker,
                                    generation_tokenizer, generation_model, system_prompt, 
                                    n=3, max_workers=4, save_interval=50):
    """
    Process all queries using enhanced RAG pipeline with multiprocessing
    
    Returns:
        tuple: (generated_answers, combined_contexts, success_count, skipped_count)
    """
    print(f"🚀 Starting Enhanced RAG processing with {max_workers} workers...")
    print(f"Processing {len(queries)} queries with query rewriting and reranking...")
    
    # Check for existing results
    existing_answers = []
    existing_contexts = []
    
    if 'generated_answer' in queries.columns:
        existing_answers = queries['generated_answer'].fillna("").tolist()
        print(f"Found {sum(1 for ans in existing_answers if ans and ans != '')} existing answers")
    
    if 'combined_context' in queries.columns:
        existing_contexts = queries['combined_context'].fillna("").tolist()
    
    # Prepare arguments for each query
    query_args = []
    for idx, (question, embedding) in enumerate(zip(queries['question'], query_embeddings)):
        query_args.append((idx, question, embedding, client, embedding_model,
                          query_rewriter_tokenizer, query_rewriter, reranker,
                          generation_tokenizer, generation_model, system_prompt, n, existing_answers))
    
    # Initialize results storage
    generated_answers = existing_answers[:] if existing_answers else [""] * len(queries)
    combined_contexts = existing_contexts[:] if existing_contexts else [""] * len(queries)
    success_count = 0
    skipped_count = 0
    
    # Process queries in parallel
    start_time = time.time()
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_idx = {executor.submit(process_single_enhanced_query, args): args[0] for args in query_args}
        
        # Process completed tasks with progress bar
        with tqdm(total=len(queries), desc="Enhanced RAG Processing") as pbar:
            for i, future in enumerate(as_completed(future_to_idx)):
                query_idx, answer, context, success, skipped = future.result()
                
                # Store results
                generated_answers[query_idx] = answer
                if context:  # Only update context if we have a new one
                    combined_contexts[query_idx] = context
                
                if success:
                    success_count += 1
                if skipped:
                    skipped_count += 1
                
                pbar.update(1)
                
                # Periodic saving to prevent data loss
                if (i + 1) % save_interval == 0:
                    temp_df = queries.copy()
                    temp_df['generated_answer'] = generated_answers
                    temp_df['combined_context'] = combined_contexts
                    temp_df.to_csv("../results/enhanced_rag_generated_answers_temp.csv", index=False)
                    print(f"\n💾 Enhanced RAG progress saved at query {i + 1}")
    
    end_time = time.time()
    processing_time = end_time - start_time
    
    print(f"\n🎉 Enhanced RAG processing completed!")
    print(f"Total time: {processing_time:.2f} seconds")
    print(f"Average time per query: {processing_time/len(queries):.2f} seconds")
    print(f"Successful queries: {success_count}/{len(queries)}")
    print(f"Skipped queries: {skipped_count}/{len(queries)}")
    print(f"Features used: Query Rewriting + Reranking + Multiprocessing")
    
    return generated_answers, combined_contexts, success_count, skipped_count

print("Enhanced RAG pipeline functions defined:")
print("- process_single_enhanced_query(): Single query processing with rewriting + reranking")
print("- process_enhanced_queries_parallel(): Parallel processing with all enhancements")


Enhanced RAG pipeline functions defined:
- process_single_enhanced_query(): Single query processing with rewriting + reranking
- process_enhanced_queries_parallel(): Parallel processing with all enhancements


In [6]:
# Test Enhanced RAG on a Single Query
print("🧪 Testing Enhanced RAG on a single query...")

# Test query
test_query = queries['question'].iloc[0]
test_embedding = embedding_model.encode([test_query])[0]

print(f"Original query: {test_query}")

# Step 1: Query Rewriting
rewritten_query = rewrite_query(test_query, query_rewriter_tokenizer, query_rewriter)
print(f"Rewritten query: {rewritten_query}")

# Step 2: Search with rewritten query
rewritten_embedding = embedding_model.encode([rewritten_query])[0]
search_results = search_and_fetch_top_n_passages(rewritten_embedding, limit=5)

# Extract passages
retrieved_passages = []
for i in range(len(search_results[0])):
    retrieved_passages.append(search_results[0][i]['entity']['passage'])

print(f"\nRetrieved {len(retrieved_passages)} passages")

# Step 3: Rerank passages
reranked_passages = rerank_passages(rewritten_query, retrieved_passages, reranker, top_k=3)

print(f"\nReranked passages:")
for i, (passage, score) in enumerate(reranked_passages):
    print(f"Passage {i+1} (score: {score:.3f}): {passage[:100]}...")

# Step 4: Generate answer
system_prompt = """You are a helpful assistant that answers questions based on the provided context. 
Use only the information from the context to answer the question. 
If the context doesn't contain enough information to answer the question, say so.
Be concise and accurate in your response."""

top_passages = [passage for passage, score in reranked_passages]
combined_context = "\n\n".join(top_passages)

prompt = f"""{system_prompt}\n
Context: {combined_context}\n
Question: {test_query}"""

inputs = generation_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

with torch.no_grad():
    outputs = generation_model.generate(
        inputs.input_ids,
        max_length=150,
        num_beams=4,
        early_stopping=True,
        do_sample=False
    )

answer = generation_tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"\n🎯 Enhanced RAG Result:")
print(f"Question: {test_query}")
print(f"Answer: {answer}")
print(f"Context length: {len(combined_context)} characters")

# Clear memory
del inputs, outputs
gc.collect()


🧪 Testing Enhanced RAG on a single query...
Original query: Was Abraham Lincoln the sixteenth President of the United States?
Rewritten query: Was Abraham Lincoln the sixteenth President of the United States?

Retrieved 5 passages

Reranked passages:
Passage 1 (score: 9.051): Abraham Lincoln (February 12, 1809 â April 15, 1865) was the sixteenth President of the United Sta...
Passage 2 (score: 7.637): On November 6, 1860, Lincoln was elected as the 16th President of the United States, beating Democra...
Passage 3 (score: -1.837): Young Abraham Lincoln...

🎯 Enhanced RAG Result:
Question: Was Abraham Lincoln the sixteenth President of the United States?
Answer: yes.
Context length: 2559 characters


91

In [7]:
# Process All Queries with Enhanced RAG
print("🚀 Starting Enhanced RAG processing for all queries...")

# Check for existing temporary file (crash recovery)
import os
temp_file_path = "../results/enhanced_rag_generated_answers_temp.csv"
if os.path.exists(temp_file_path):
    print(f"🔄 Found existing temporary file: {temp_file_path}")
    print("Loading existing progress to resume processing...")
    temp_queries = pd.read_csv(temp_file_path)
    
    # Check if we have any existing results
    if 'generated_answer' in temp_queries.columns:
        existing_count = sum(1 for ans in temp_queries['generated_answer'].fillna("") if ans and ans != "")
        print(f"Found {existing_count} existing answers in temporary file")
        
        # Update queries with existing results
        queries = temp_queries.copy()
        print("Resuming from existing progress...")
    else:
        print("Temporary file found but no valid results, starting fresh...")
else:
    print("No existing temporary file found, starting fresh...")

# Generate query embeddings
query_embeddings = embedding_model.encode(queries['question'].tolist())
print(f"Generated embeddings for {len(query_embeddings)} queries")

# System prompt
system_prompt = """You are a helpful assistant that answers questions based on the provided context. 
Use only the information from the context to answer the question. 
If the context doesn't contain enough information to answer the question, say so.
Be concise and accurate in your response."""

# Enhanced RAG parameters
n = 3  # Use top 3 reranked contexts
max_workers = 4  # Number of parallel workers
save_interval = 50  # Save progress every 50 queries

# Process all queries with enhanced RAG
generated_answers, combined_contexts, success_count, skipped_count = process_enhanced_queries_parallel(
    queries=queries,
    query_embeddings=query_embeddings,
    client=client,
    embedding_model=embedding_model,
    query_rewriter_tokenizer=query_rewriter_tokenizer,
    query_rewriter=query_rewriter,
    reranker=reranker,
    generation_tokenizer=generation_tokenizer,
    generation_model=generation_model,
    system_prompt=system_prompt,
    n=n,
    max_workers=max_workers,
    save_interval=save_interval
)

# Add results to the queries dataframe
queries['generated_answer'] = generated_answers
queries['combined_context'] = combined_contexts

print(f"\n🎉 Enhanced RAG Results:")
print(f"Total queries processed: {len(queries)}")
print(f"Successful queries: {success_count}")
print(f"Skipped queries: {skipped_count}")
print(f"Success rate: {success_count/len(queries)*100:.1f}%")
print(f"Enhanced features: Query Rewriting + Reranking + Multiprocessing")


🚀 Starting Enhanced RAG processing for all queries...
No existing temporary file found, starting fresh...
Generated embeddings for 918 queries
🚀 Starting Enhanced RAG processing with 4 workers...
Processing 918 queries with query rewriting and reranking...


Enhanced RAG Processing:   5%|▌         | 50/918 [06:30<51:29,  3.56s/it]  


💾 Enhanced RAG progress saved at query 50


Enhanced RAG Processing:  11%|█         | 100/918 [12:26<1:21:57,  6.01s/it]


💾 Enhanced RAG progress saved at query 100


Enhanced RAG Processing:  16%|█▋        | 150/918 [17:40<48:32,  3.79s/it]  


💾 Enhanced RAG progress saved at query 150


Enhanced RAG Processing:  22%|██▏       | 200/918 [24:10<1:44:29,  8.73s/it]


💾 Enhanced RAG progress saved at query 200


Enhanced RAG Processing:  27%|██▋       | 250/918 [27:32<45:48,  4.12s/it]  


💾 Enhanced RAG progress saved at query 250


Enhanced RAG Processing:  33%|███▎      | 300/918 [32:00<50:54,  4.94s/it]  


💾 Enhanced RAG progress saved at query 300


Enhanced RAG Processing:  38%|███▊      | 350/918 [37:36<1:25:18,  9.01s/it]


💾 Enhanced RAG progress saved at query 350


Enhanced RAG Processing:  44%|████▎     | 401/918 [41:52<30:42,  3.56s/it]  


💾 Enhanced RAG progress saved at query 400


Enhanced RAG Processing:  49%|████▉     | 450/918 [47:08<1:21:18, 10.43s/it]


💾 Enhanced RAG progress saved at query 450


Enhanced RAG Processing:  54%|█████▍    | 500/918 [52:05<26:30,  3.81s/it]  


💾 Enhanced RAG progress saved at query 500


Enhanced RAG Processing:  60%|██████    | 551/918 [57:37<53:56,  8.82s/it]  


💾 Enhanced RAG progress saved at query 550


Enhanced RAG Processing:  65%|██████▌   | 600/918 [1:03:34<30:08,  5.69s/it]  


💾 Enhanced RAG progress saved at query 600


Enhanced RAG Processing:  71%|███████   | 651/918 [1:11:15<23:50,  5.36s/it]  


💾 Enhanced RAG progress saved at query 650


Enhanced RAG Processing:  76%|███████▋  | 700/918 [1:17:57<19:09,  5.27s/it]  


💾 Enhanced RAG progress saved at query 700


Enhanced RAG Processing:  82%|████████▏ | 750/918 [1:24:22<15:42,  5.61s/it]


💾 Enhanced RAG progress saved at query 750


Enhanced RAG Processing:  87%|████████▋ | 800/918 [1:31:57<22:55, 11.66s/it]


💾 Enhanced RAG progress saved at query 800


Enhanced RAG Processing:  93%|█████████▎| 850/918 [1:40:06<10:39,  9.41s/it]


💾 Enhanced RAG progress saved at query 850


Enhanced RAG Processing:  98%|█████████▊| 900/918 [1:49:38<02:05,  6.95s/it]


💾 Enhanced RAG progress saved at query 900


Enhanced RAG Processing: 100%|██████████| 918/918 [1:54:44<00:00,  7.50s/it]



🎉 Enhanced RAG processing completed!
Total time: 6884.45 seconds
Average time per query: 7.50 seconds
Successful queries: 918/918
Skipped queries: 0/918
Features used: Query Rewriting + Reranking + Multiprocessing

🎉 Enhanced RAG Results:
Total queries processed: 918
Successful queries: 918
Skipped queries: 0
Success rate: 100.0%
Enhanced features: Query Rewriting + Reranking + Multiprocessing


In [8]:
# Save Enhanced RAG Results
print("💾 Saving Enhanced RAG results...")

# Save to CSV
queries.to_csv("../results/enhanced_rag_answers.csv", index=False)
print("Enhanced RAG results saved to: ../results/enhanced_rag_answers.csv")

# Display sample results
print("\n📊 Sample Enhanced RAG Results:")
display_columns = ['question', 'answer', 'generated_answer', 'combined_context']
sample_results = queries[display_columns].head(5)

for idx, row in sample_results.iterrows():
    print(f"\n--- Query {idx + 1} ---")
    print(f"Question: {row['question']}")
    print(f"Ground Truth: {row['answer']}")
    print(f"Enhanced RAG Answer: {row['generated_answer']}")
    print(f"Context Length: {len(row['combined_context'])} characters")
    print(f"Context Preview: {row['combined_context'][:150]}...")

print(f"\n✅ Enhanced RAG processing completed!")
print(f"Results saved with query rewriting and reranking improvements")


💾 Saving Enhanced RAG results...
Enhanced RAG results saved to: ../results/enhanced_rag_answers.csv

📊 Sample Enhanced RAG Results:

--- Query 1 ---
Question: Was Abraham Lincoln the sixteenth President of the United States?
Ground Truth: yes
Enhanced RAG Answer: yes.
Context Length: 2559 characters
Context Preview: Abraham Lincoln (February 12, 1809 â April 15, 1865) was the sixteenth President of the United States, serving from March 4, 1861 until his assassin...

--- Query 2 ---
Question: Did Lincoln sign the National Banking Act of 1863?
Ground Truth: yes
Enhanced RAG Answer: Yes.
Context Length: 2246 characters
Context Preview: Lincoln believed in the Whig theory of the presidency, which left Congress to write the laws while he signed them, vetoing only those bills that threa...

--- Query 3 ---
Question: Did his mother die of pneumonia?
Ground Truth: no
Enhanced RAG Answer: No.
Context Length: 1184 characters
Context Preview: Alice Hathaway Lee Roosevelt (July 29, 1861 in Chest

In [9]:
# Cleanup temporary files
import os

# Remove temporary file if exists
temp_file_path = "../results/enhanced_rag_generated_answers_temp.csv"
if os.path.exists(temp_file_path):
    os.remove(temp_file_path)
    print(f"🗑️ Removed temporary file: {temp_file_path}")
else:
    print("No temporary file to clean up")


🗑️ Removed temporary file: ../results/enhanced_rag_generated_answers_temp.csv
