# RAG

Implement a baseline RAG (retrieval-augmented generation) module in DSPy.
Given a question, retrieve the top-k most relevant documents from a collection of HTML documents, then pass them as context to an LLM.

Refer to https://dspy.ai/tutorials/rag/. 


In [1]:
import dspy
from sentence_transformers import SentenceTransformer

# Load an extremely efficient local model for retrieval.
# device=None lets sentence-transformers choose a GPU/accelerator when one is available
# and fall back to CPU otherwise. ("gpu" is not a valid torch device string; valid
# examples are "cpu", "cuda", "cuda:0", "mps".)
model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device=None)

# Create an embedder using the model's encode method
embedder = dspy.Embedder(model.encode)

# Traverse a directory and read html files - extract text from the html files
import os
from bs4 import BeautifulSoup
def read_html_files(directory, *, separator=""):
    """Return the text content of every ``.html`` file directly inside *directory*.

    Each file is read as UTF-8 and parsed with BeautifulSoup's ``html.parser``.
    Subdirectories are not searched.

    Filenames are processed in sorted order. ``os.listdir`` returns entries in
    an arbitrary order, and sorting keeps the corpus (and any index built from
    it) identical across runs.

    Args:
        directory: Folder containing the HTML files.
        separator: String inserted between text fragments by ``get_text``. The
            default ``""`` matches the previous behaviour. Use e.g. ``"\\n"`` to
            keep words from adjacent tags from being fused together.

    Returns:
        A list with one text string per ``.html`` file.
    """
    texts = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".html"):
            continue
        with open(os.path.join(directory, filename), 'r', encoding='utf-8') as file:
            # The whole file is parsed here, so the handle can be closed before get_text().
            soup = BeautifulSoup(file, 'html.parser')
        texts.append(soup.get_text(separator))
    return texts

ModuleNotFoundError: No module named 'dspy'

In [None]:
# Extract the text of every HTML page in the Zelda source folder; these pages form the retrieval corpus.
corpus = read_html_files(
    "../PragmatiCQA-sources/The Legend of Zelda",
)
print(f"Loaded {len(corpus)} documents. Will encode them below.")

In [None]:
# Parameters for the retriever
max_characters = 10000  # for truncating >99th percentile of documents
topk_docs_to_retrieve = 5  # number of documents to retrieve per search query

# Truncate every document to max_characters *before* indexing. This caps the
# length of each indexed passage, and therefore of each passage that search()
# hands to the LM as context.
search = dspy.retrievers.Embeddings(
    embedder=embedder,
    corpus=[doc[:max_characters] for doc in corpus],
    k=topk_docs_to_retrieve,
)



In [None]:
# Language model used by every DSPy module in this notebook.
# Local alternative via Ollama:
#   lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
lm = dspy.LM(model='xai/grok-3-mini')
dspy.configure(lm=lm)

In [None]:
class RAG(dspy.Module):
    """Retrieve passages with the notebook's ``search`` retriever and answer with chain-of-thought.

    The retriever is read from the module-level ``search`` name at call time.
    """

    def __init__(self):
        # Initialise dspy.Module's base state before adding sub-modules.
        super().__init__()
        self.respond = dspy.ChainOfThought('context, question -> response')

    def forward(self, question):
        """Answer *question* using the passages returned by ``search(question)`` as context.

        Returns the ChainOfThought output; the answer is in its ``response`` field.
        """
        context = search(question).passages
        return self.respond(context=context, question=question)


rag = RAG()

In [None]:
# Example query; the final answer text lives in the prediction's `response` field.
answer = rag(
    question="What is the main plot of The Legend of Zelda?",
)
print(answer.response)

In [None]:
# A short factual question to sanity-check retrieval + generation.
q = 'What year did the Legend of Zelda come out?'
prediction = rag(question=q)
print(prediction.response)

Part 4.3 — Traditional QA baseline (start)

Plan:
- Use the existing retriever `search` to obtain retrieved passages for a question.
- Use Hugging Face's `distilbert/distilbert-base-cased-distilled-squad` extractive QA pipeline to answer the question given:
  1) Literal spans (from dataset),
  2) Pragmatic spans (from dataset),
  3) Retrieved context (from `search`).
- Evaluate these three configurations with dspy.evaluate.SemanticF1 on the first question of each conversation.

In [None]:
# Loads validation data, builds HF QA pipeline, runs example predictions and computes SemanticF1.
from transformers import pipeline
import json, os
import dspy
from dspy.evaluate import SemanticF1

# Helper to load JSON Lines data (redefined here so this section can run on its own).
def read_data(filename, dataset_dir="../PragmatiCQA/data"):
    """Load ``dataset_dir/filename`` as JSON Lines and return the decoded records as a list.

    Blank or whitespace-only lines (e.g. an extra empty line at the end of the
    file) are skipped. Without this, ``json.loads`` would raise
    ``json.JSONDecodeError`` on them.
    """
    path = os.path.join(dataset_dir, filename)
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

# Validation split; only the first question of each conversation is evaluated below.
val = read_data("val.jsonl")

# Extractive QA baseline. No `device` argument is passed, so transformers uses its
# default device; pass e.g. device=0 to run on the first GPU.
qa_pipeline = pipeline(
    task="question-answering",
    model="distilbert/distilbert-base-cased-distilled-squad",
    tokenizer="distilbert/distilbert-base-cased-distilled-squad",
)

# LM-judged semantic F1 (decompositional variant); the judge runs on the LM set via dspy.configure.
metric = SemanticF1(decompositional=True)

def _join_spans(spans):
    """Concatenate the ``"text"`` field of each annotated span into a single context string."""
    return " ".join(span["text"] for span in spans).strip()


def _extract_answer(qa_pipeline, question, context):
    """Return the pipeline's extracted answer for *question* over *context* ("" if no context)."""
    if not context:
        return ""
    out = qa_pipeline(question=question, context=context)
    # A single question/context pair normally yields a dict; other configurations
    # (e.g. top_k > 1) may yield a list of dicts, in which case the first one is used.
    if isinstance(out, dict):
        return out.get("answer", "")
    return out[0].get("answer", "") if out else ""


def run_traditional_qa_on_first_question(example, search, qa_pipeline, metric):
    """Answer the first question of a PragmatiCQA conversation with an extractive QA model.

    The question is answered three times, from three different contexts:
    the literal answer spans, the pragmatic answer spans, and the passages
    returned by ``search(question)``. Each prediction is then scored against
    the gold answer with *metric*.

    Args:
        example: One PragmatiCQA record (a dict with a ``"qas"`` list).
        search: Retriever; ``search(question).passages`` must be a list of strings.
        qa_pipeline: Hugging Face question-answering pipeline.
        metric: Callable ``metric(gold_example, prediction) -> score``, such as
            ``dspy.evaluate.SemanticF1``.

    Returns:
        ``None`` if the record has no questions. Otherwise a dict with the keys
        ``question``, ``gold``, ``pred_literal``, ``pred_pragmatic``,
        ``pred_retrieved``, and ``scores``. ``scores`` maps ``literal`` /
        ``pragmatic`` / ``retrieved`` to the metric value.
    """
    if not example.get("qas"):
        return None
    qa = example["qas"][0]
    question = qa["q"]
    gold = qa["a"]

    # Contexts: gold-annotated spans (tolerating a missing/empty a_meta) and retrieved passages.
    a_meta = qa.get("a_meta") or {}
    literal_context = _join_spans(a_meta.get("literal_obj") or [])
    pragmatic_context = _join_spans(a_meta.get("pragmatic_obj") or [])
    retrieved_context = " ".join(search(question).passages).strip()

    predictions = {
        "literal": _extract_answer(qa_pipeline, question, literal_context),
        "pragmatic": _extract_answer(qa_pipeline, question, pragmatic_context),
        "retrieved": _extract_answer(qa_pipeline, question, retrieved_context),
    }

    # Gold is a proper dspy.Example with `question` marked as its input; each system
    # answer is wrapped as a dspy.Prediction, which is what DSPy metrics expect.
    gold_ex = dspy.Example(question=question, response=gold).with_inputs("question")
    scores = {
        name: metric(gold_ex, dspy.Prediction(response=pred))
        for name, pred in predictions.items()
    }

    return {
        "question": question,
        "gold": gold,
        "pred_literal": predictions["literal"],
        "pred_pragmatic": predictions["pragmatic"],
        "pred_retrieved": predictions["retrieved"],
        "scores": scores,
    }

# Smoke test: run the three baselines on at most the first 10 validation documents,
# print each outcome, and collect the per-configuration scores in `results`.
results = []
for i, doc in enumerate(val[:10]):
    res = run_traditional_qa_on_first_question(doc, search, qa_pipeline, metric)
    if not res:
        continue  # document has no questions
    print(f"Example {i+1}:")
    print("Q:", res["question"])
    print("Gold (truncated):", res["gold"] if len(res["gold"]) <= 200 else res["gold"][:200] + "...")
    for label in ("literal", "pragmatic", "retrieved"):
        print(f"Pred ({label}):", res[f"pred_{label}"])
    print("Scores:", res["scores"])
    print("-" * 80)
    results.append(res["scores"])