In [None]:
# ============================================================
# Enhanced RAG (Step 5): Query Rewriting → Retrieval → Reranking → Generation
# Baseline choices aligned with Step 3/4 findings:
#   - Prompt: Persona
#   - Embedding: all-mpnet-base-v2 (768-dim)
#   - K (final context size): 5
#   - Two Enhancements: Query Rewriting + Reranking (Cross-Encoder)
# ============================================================

In [None]:
import os
# Turn off HuggingFace tokenizers' internal parallelism before any tokenizer is
# created (commonly done to avoid its fork warning when processes fork).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [1]:
import os
import time
import pandas as pd
from tqdm import tqdm

from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import pipeline
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType



In [2]:
# for reranking (cross-encoder) and type hints
from sentence_transformers import CrossEncoder 
from typing import List, Dict, Tuple

In [3]:
# Paths (relative to the notebook's working directory)
INPUT_QA_PATH = "../data/evaluation/all_experiment_generations.csv"  # has question & answer
PASSAGES_PATH = "../data/processed/rag_mini_wiki.csv"                 # passages to index

# Echo the resolved absolute locations so a wrong working directory is obvious.
for _label, _path in (("Reading QA file from:", INPUT_QA_PATH), ("Passages file:", PASSAGES_PATH)):
    print(_label, os.path.abspath(_path))
del _label, _path

Reading QA file from: /Users/connie/Desktop/Fall 2025/LLM/Assignment2/data/evaluation/all_experiment_generations.csv
Passages file: /Users/connie/Desktop/Fall 2025/LLM/Assignment2/data/processed/rag_mini_wiki.csv


## Load QA

In [5]:
# Load the QA pairs and keep only the two columns enhanced generation needs
# (.copy() so later column additions never touch the freshly-read frame).
qa_df = pd.read_csv(INPUT_QA_PATH).loc[:, ["question", "answer"]].copy()
print("QA shape:", qa_df.shape)
qa_df.head()

QA shape: (120, 2)


Unnamed: 0,question,answer
0,What did James Monroe make in 1817?,two long tours
1,Are Gray Wolves native to North America?,Yes
2,Is English the official language?,yes
3,How long is the elephant's gestation period?,22 months
4,Are diving ducks heavier tha dabbling ducks?,Yes


## Models & Milvus Client (new DB)

In [7]:
# Query-rewriting and answer-generation models. Both steps use the same
# checkpoint, so one pipeline is loaded and shared instead of loading the
# identical model twice; point GENERATOR_NAME elsewhere to get a separate one.
REWRITER_NAME = "google/flan-t5-base"
GENERATOR_NAME = "google/flan-t5-base"
query_rewriter = pipeline("text2text-generation", model=REWRITER_NAME)
generator = (
    query_rewriter
    if GENERATOR_NAME == REWRITER_NAME
    else pipeline("text2text-generation", model=GENERATOR_NAME)
)

# Embedding model (bi-encoder used for dense retrieval)
EMBEDDING_NAME = "all-mpnet-base-v2"
embedding_model = SentenceTransformer(EMBEDDING_NAME)

# Cross-encoder reranker: scores (query, passage) pairs
RERANKER_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
reranker = CrossEncoder(RERANKER_NAME)

# Use a new, separate Milvus Lite DB file for the enhanced experiment
ENHANCED_DB_NAME = "rag_wikipedia_mini_enhanced.db"
client = MilvusClient(ENHANCED_DB_NAME)
print("Milvus DB:", os.path.abspath(ENHANCED_DB_NAME))

Device set to use mps:0
Device set to use mps:0


Milvus DB: /Users/connie/Desktop/Fall 2025/LLM/Assignment2/src/rag_wikipedia_mini_enhanced.db


In [8]:
def rebuild_milvus_index(client: MilvusClient, passages_csv_path: str, embedding_model: SentenceTransformer, model_name_str: str,
                         collection_name: str = "rag_mini", max_passage_length: int = 3000, batch_size: int = 64):
    """
    Build a fresh Milvus collection inside the enhanced DB from a passages CSV.

    Any existing collection with the same name is dropped first, so every call
    re-encodes and re-inserts all passages.

    Args:
        client: Milvus client bound to the target DB file.
        passages_csv_path: CSV with a ``passages`` column; a row's position
            becomes its primary key ``id``.
        embedding_model: Model used to embed each passage.
        model_name_str: Display name, used only in log output.
        collection_name: Collection to (re)create. Defaults to ``"rag_mini"``,
            the name the retrieval code in this notebook searches.
        max_passage_length: ``max_length`` of the VARCHAR ``passage`` field.
        batch_size: Batch size passed to ``embedding_model.encode``.

    Raises:
        AssertionError: if the CSV has no ``passages`` column.
        ValueError: if the CSV has no rows, a passage is missing (NaN), or a
            passage is longer than ``max_passage_length`` characters. These
            checks run before the slow encoding step.
    """
    print(f"\n[Build Milvus] Rebuilding collection with {model_name_str} ...")
    df_passages = pd.read_csv(passages_csv_path)
    assert "passages" in df_passages.columns, "Expected a 'passages' column in passages CSV."

    # 0) Validate up front so bad data fails before encoding, not at insert time
    passages = df_passages["passages"]
    if passages.empty:
        raise ValueError(f"No passages found in {passages_csv_path}.")
    n_missing = int(passages.isna().sum())
    if n_missing:
        raise ValueError(f"{n_missing} passage(s) are missing (NaN) in {passages_csv_path}.")
    texts = passages.astype(str).tolist()
    # Character count; Milvus still enforces its own limit on insert.
    too_long = [row for row, text in enumerate(texts) if len(text) > max_passage_length]
    if too_long:
        raise ValueError(
            f"{len(too_long)} passage(s) exceed max_passage_length={max_passage_length} "
            f"characters (first at row {too_long[0]}); raise the limit or split them."
        )

    # 1) Encode
    t0 = time.perf_counter()
    embeddings = embedding_model.encode(texts, batch_size=batch_size)
    vector_dim = embeddings.shape[1]
    print(f"Encoded {len(texts)} passages → dim={vector_dim} in {time.perf_counter()-t0:.1f}s")

    # 2) Define schema
    id_field        = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False)
    passage_field   = FieldSchema(name="passage", dtype=DataType.VARCHAR, max_length=max_passage_length)
    embedding_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim)
    schema = CollectionSchema(fields=[id_field, passage_field, embedding_field], description="Enhanced passages")

    # 3) Drop & recreate
    if client.has_collection(collection_name):
        client.drop_collection(collection_name)
    client.create_collection(collection_name=collection_name, schema=schema, consistency_level="Strong")

    # 4) Insert: walk the extracted lists rather than building a row Series per passage via iloc
    rag_data = [
        {"id": row, "passage": text, "embedding": vector.tolist()}
        for row, (text, vector) in enumerate(zip(texts, embeddings))
    ]
    client.insert(collection_name=collection_name, data=rag_data)

    # 5) Index
    # NOTE(review): Milvus Lite is documented to support only FLAT indexes, so
    # IVF_FLAT/nlist here (and nprobe at search time) may be ignored — confirm
    # for the installed version.
    index_params = MilvusClient.prepare_index_params()
    index_params.add_index(
        field_name="embedding",
        index_type="IVF_FLAT",
        metric_type="COSINE",
        params={"nlist": 128}
    )
    client.create_index(collection_name, index_params=index_params)
    client.load_collection(collection_name)

    # 6) Report
    stats = client.get_collection_stats(collection_name)
    print("Entity count:", stats["row_count"])
    print("[Build Milvus] Done.")

# (Re)build the enhanced DB: the 'rag_mini' collection is dropped and
# recreated every time this cell runs.
rebuild_milvus_index(
    client=client,
    passages_csv_path=PASSAGES_PATH,
    embedding_model=embedding_model,
    model_name_str=EMBEDDING_NAME,
)


[Build Milvus] Rebuilding collection with all-mpnet-base-v2 ...


I0000 00:00:1759172786.717736 11151799 fork_posix.cc:71] Other threads are currently calling into gRPC, skipping fork() handlers


Encoded 3200 passages → dim=768 in 49.9s
Entity count: 3200
[Build Milvus] Done.


In [9]:
# Retrieval budget: TOP_N dense hits are reranked down to TOP_K context passages.
TOP_N, TOP_K = 10, 5

def rewrite_query(query: str, rewriter=None, max_new_tokens: int = 64) -> str:
    """Rewrite a user query into clearer, retrieval-friendly text.

    Args:
        query: The original user question.
        rewriter: A text2text-generation callable with the HF pipeline call
            shape. It defaults to the notebook's global ``query_rewriter``.
        max_new_tokens: Generation budget for the rewrite.

    Returns:
        The stripped rewrite. If the model returns only whitespace, ``query``
        is returned unchanged, because an empty string would produce a
        meaningless search embedding.
    """
    if rewriter is None:
        rewriter = query_rewriter
    prompt = f"""
    You are a query rewriter for retrieval systems.
    Task: Rewrite the following question into a clearer, concise form. 
    Rules:
    - Do NOT answer the question.
    - Do NOT change the meaning of the question.
    - Keep named entities and dates intact.
    - Keep the rewritten query in a question format.

    Original Question: {query}
    Rewritten Query:
    """
    out = rewriter(prompt, max_new_tokens=max_new_tokens, do_sample=False)[0]["generated_text"]
    rewritten = out.strip()
    return rewritten or query

def retrieve_and_rerank(query: str, orig_query: str, top_n: int = TOP_N, top_k: int = TOP_K):
    """Fetch candidates with the rewritten query, then rank them against the original.

    ``query`` (the rewrite) is embedded and used to pull ``top_n`` passages from
    the 'rag_mini' collection. The cross-encoder then scores each candidate
    against ``orig_query``, and the ``top_k`` highest-scoring passage texts are
    returned best-first. The result is an empty list when the search finds
    nothing.
    """
    hits = client.search(
        collection_name="rag_mini",
        data=[embedding_model.encode(query).tolist()],
        anns_field="embedding",
        search_params={"metric_type": "COSINE", "params": {"nprobe": 10}},
        limit=top_n,
        output_fields=["id", "passage"],
    )[0]
    candidates = [hit["entity"]["passage"] for hit in hits]
    if len(candidates) == 0:
        return []

    # Score against the user's original wording, not the rewrite.
    scores = reranker.predict([[orig_query, passage] for passage in candidates])
    # Stable descending sort of positions by score (ties keep retrieval order).
    order = sorted(range(len(candidates)), key=lambda idx: scores[idx], reverse=True)
    return [candidates[idx] for idx in order[:top_k]]


# Persona-strategy instructions placed at the top of every generation prompt.
PERSONA_PROMPT = (
    "You are a historical expert known for concise and authoritative answers.\n"
    "Answer ONLY using the provided context. "
    "If the context does not contain the answer, "
    'say "Information not available in the context."\n'
    "Keep the answer short and factual.\n"
)

def generate_answer_with_persona(orig_query: str, passages: list[str]) -> str:
    """Answer ``orig_query`` with the Persona prompt over the given passages.

    Passages are joined with ``"\\n---\\n"`` separators. An empty or missing
    list yields an empty context block. The raw generated text from the global
    ``generator`` pipeline is returned as-is.
    """
    context = "\n---\n".join(passages or [])
    prompt = f"{PERSONA_PROMPT}\n\nContext:\n{context}\n\nQuestion: {orig_query}\n"
    generation = generator(prompt, max_new_tokens=200, do_sample=False)
    first = generation[0]
    return first["generated_text"]

In [10]:
# Per-question artifacts of the enhanced pipeline, collected in question order.
enhanced_answers   = []
rewritten_queries  = []
used_passage_sizes = []
contexts_joined    = []

t0 = time.time()
for orig_q in tqdm(qa_df["question"].tolist(), desc="Enhanced RAG"):
    # 1) Query rewriting (the rewrite drives dense retrieval only)
    rew_q = rewrite_query(orig_q)
    rewritten_queries.append(rew_q)

    # 2) Retrieve TOP_N with the rewrite, rerank against the original, keep TOP_K
    top_passages = retrieve_and_rerank(rew_q, orig_q, top_n=TOP_N, top_k=TOP_K)
    used_passage_sizes.append(len(top_passages))
    contexts_joined.append("\n---\n".join(top_passages))  # join the top-k passages into one context string

    # 3) Persona-based generation from the original question
    enhanced_answers.append(generate_answer_with_persona(orig_q, top_passages))

print(f"Total time: {time.time()-t0:.1f}s for {len(qa_df)} queries")

# Attach results as new columns; row order matches the loop above.
for column, values in (
    ("rewritten_query", rewritten_queries),
    ("gen_enhanced_k5_persona_mpnet", enhanced_answers),
    ("ctx_count_used", used_passage_sizes),
    ("contexts", contexts_joined),
):
    qa_df[column] = values

Enhanced RAG:   0%|                                     | 0/120 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (726 > 512). Running this sequence through the model will result in indexing errors
Enhanced RAG: 100%|███████████████████████████| 120/120 [04:28<00:00,  2.24s/it]

Total time: 268.5s for 120 queries





In [16]:
# Persist the in-notebook enhanced run (question, answer, rewrite, answer, contexts).
out_path = os.path.join("../data/evaluation", "enhanced_experiment_generations.csv")
# to_csv does not create missing directories; the Step 6 save cell does the same.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
qa_df.to_csv(out_path, index=False)
print(f"✅ Enhanced results saved → {os.path.abspath(out_path)}")

✅ Enhanced results saved → /Users/connie/Desktop/Fall 2025/LLM/Assignment2/data/evaluation/enhanced_experiment_generations.csv


### Enhanced RAG pipeline

In [15]:
def _generate_query_variants(query: str, generator, max_new_tokens: int = 64) -> List[str]:
    """Return the queries to search with: the original question plus one model rewrite.

    Empty strings and exact duplicates are dropped (first occurrence kept), so a
    rewrite that merely echoes the question costs no extra search.
    """
    prompt = (
        "You are a query rewriter for retrieval systems.\n"
        "Task: Rewrite the following question into a clearer, concise form.\n"
        "Rules:\n"
        "- Do NOT answer the question.\n"
        "- Do NOT change the meaning of the question.\n"
        "- Keep named entities and dates intact.\n"
        "- Keep the rewritten query in a question format.\n\n"
        f"Original Question: {query}\n"
        "Rewritten Query:"
    )
    out = generator(prompt, max_new_tokens=max_new_tokens, do_sample=False)
    rewritten = out[0]["generated_text"].strip()
    return list(dict.fromkeys(q for q in (query, rewritten) if q))


def _rerank_top_k(query: str, passages: List[str], reranker: CrossEncoder, k: int) -> List[str]:
    """Return the ``k`` passages with the highest cross-encoder score for ``query``, best first."""
    scores = reranker.predict([[query, passage] for passage in passages])
    ranked = sorted(zip(passages, scores), key=lambda pair: pair[1], reverse=True)
    return [passage for passage, _ in ranked[:k]]


def enhanced_rag_pipeline(query: str, client: MilvusClient, embedding_model: SentenceTransformer, generator, reranker: CrossEncoder, 
                          k_retrieval: int, k_rerank: int, strategy_template: str,
                          collection_name: str = "rag_mini") -> Tuple[str, List[str]]:
    """
    Execute the Enhanced RAG workflow: query rewriting -> multi-query retrieval -> reranking -> generation.

    Args:
        query: The user's question. It is also the query the reranker scores against.
        client: Milvus client holding ``collection_name``.
        embedding_model: Bi-encoder used to embed every query variant.
        generator: text2text-generation pipeline, used for the rewrite and the answer.
        reranker: Cross-encoder used to reorder the retrieved candidates.
        k_retrieval: Hits fetched per query variant before reranking.
        k_rerank: Passages kept after reranking and passed to the LLM.
        strategy_template: ``str.format`` template with ``{context}`` and ``{question}``.
        collection_name: Milvus collection to search (default ``"rag_mini"``).

    Returns:
        ``(generated_answer, final_context_list)`` for RAGAs evaluation. When
        retrieval returns nothing, the result is
        ``("Information not available in the context.", ["No context retrieved."])``.
    """
    # 1. Query rewriting: search with both the original question and its rewrite.
    all_queries = _generate_query_variants(query, generator)

    # 2. Multi-query retrieval: one batched encode and one search call, k_retrieval hits per query.
    candidates = []
    if all_queries:
        query_vectors = embedding_model.encode(all_queries).tolist()
        results = client.search(
            collection_name=collection_name,
            data=query_vectors,
            anns_field="embedding",
            search_params={"metric_type": "COSINE", "params": {"nprobe": 10}},
            limit=k_retrieval,
            output_fields=["passage"],
        )
        candidates = [hit["entity"]["passage"] for hits in results for hit in hits]

    # Deduplicate in first-seen order; list(set(...)) order depends on per-process
    # string hashing, which made the reranker's input order (and tie-breaks) vary between runs.
    unique_passages = list(dict.fromkeys(candidates))

    # 3. Reranking (filter out noise to maintain faithfulness)
    if not unique_passages:
        return "Information not available in the context.", ["No context retrieved."]
    final_context_list = _rerank_top_k(query, unique_passages, reranker, k_rerank)

    # 4. Context combination for the LLM
    context_for_llm = "\n---\n".join(final_context_list)

    # 5. Generation
    prompt = strategy_template.format(context=context_for_llm, question=query)
    out = generator(prompt, max_new_tokens=200, do_sample=False)
    final_answer = out[0]["generated_text"]

    return final_answer, final_context_list


def run_enhanced_experiment(qa_df_subset: pd.DataFrame, strategy_template=None, k_retrieval=None, k_rerank=None):
    """
    Run the Enhanced RAG pipeline over every question in ``qa_df_subset``.

    It uses the models and Milvus client created earlier in this notebook
    (``client``, ``embedding_model``, ``generator``, ``reranker``), so the cell
    no longer depends on names that only existed in a previous kernel session.

    Args:
        qa_df_subset: Frame with a ``question`` column. It is copied, not modified.
        strategy_template: ``str.format`` template with ``{context}`` and
            ``{question}`` placeholders. It defaults to the Persona prompt (the
            baseline chosen in Step 3/4), laid out like ``generate_answer_with_persona``.
        k_retrieval: Hits per query before reranking (default ``TOP_N``).
        k_rerank: Passages kept after reranking (default ``TOP_K``).

    Returns:
        A copy of ``qa_df_subset`` with ``enhanced_answer`` and
        ``enhanced_contexts`` (``str`` of each passage list, for RAGAs) added.
        An empty DataFrame is returned when the input is empty.
    """
    if qa_df_subset.empty:
        print("Error: Subset DataFrame is empty. Cannot run experiment.")
        return pd.DataFrame()

    # Resolve defaults only after the empty check, so that path needs no notebook globals.
    if strategy_template is None:
        # Escape braces so str.format fills only the two placeholders appended here.
        persona = PERSONA_PROMPT.replace("{", "{{").replace("}", "}}")
        strategy_template = persona + "\n\nContext:\n{context}\n\nQuestion: {question}\n"
    if k_retrieval is None:
        k_retrieval = TOP_N
    if k_rerank is None:
        k_rerank = TOP_K

    print(f"\n--- Starting Enhanced RAG Pipeline for {len(qa_df_subset)} Queries (Step 5/6) ---")

    generated_answers = []
    retrieved_contexts = []
    for q in tqdm(qa_df_subset["question"].tolist(), desc="Enhanced RAG (Q.R. + Rerank)"):
        answer, context_list = enhanced_rag_pipeline(
            q, client, embedding_model, generator, reranker,
            k_retrieval, k_rerank, strategy_template,
        )
        generated_answers.append(answer)
        retrieved_contexts.append(context_list)  # keep as List[str] until serialization

    # Prepare results DataFrame
    results_df = qa_df_subset.copy()
    results_df["enhanced_answer"] = generated_answers
    # Stored as str(list); convert back (e.g. ast.literal_eval) when loading for RAGAs.
    results_df["enhanced_contexts"] = [str(ctx) for ctx in retrieved_contexts]

    return results_df

In [16]:
if __name__ == "__main__":

    # Build the subset explicitly so this cell does not depend on leftover
    # kernel state. Keep only question/answer so that columns added by the
    # earlier in-notebook run are not copied into this experiment's output.
    qa_df_subset = qa_df[["question", "answer"]].copy()

    # Run the enhanced experiment
    enhanced_results_df = run_enhanced_experiment(qa_df_subset)

    if enhanced_results_df.empty:
        print("\nNo enhanced results produced; nothing saved.")
    else:
        # Save the results for Step 6 RAGAs evaluation
        os.makedirs("../data/evaluation", exist_ok=True)
        save_path = "../data/evaluation/enhanced_rag_generations.csv"
        enhanced_results_df.to_csv(save_path, index=False)

        print("\nEnhanced RAG experiment complete.")
        print(f"Results saved to {save_path} for Step 6 RAGAs evaluation.")


--- Starting Enhanced RAG Pipeline for 120 Queries (Step 5/6) ---


Enhanced RAG (Q.R. + Rerank):   0%|                     | 0/120 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (575 > 512). Running this sequence through the model will result in indexing errors
Enhanced RAG (Q.R. + Rerank):  19%|██▎         | 23/120 [01:13<05:11,  3.21s/it]


KeyboardInterrupt: 