# RAG Data Valuation: Real-World Experiment
This notebook runs an experiment on the **SQuAD** dataset to test whether Data Valuation (DV) improves RAG performance compared with a baseline retrieval approach.

In [None]:

import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
from tqdm.auto import tqdm

# Add the directory two levels up to sys.path so the `src` package can be imported
sys.path.append(os.path.abspath('../../'))

from src.dv.models.entities import Chunk
from src.dv.algorithms.loo import LOOValuator
from src.dv.evaluation.judges import MNLIJudge
from src.dv.core import ValuationSuite
from src.dv.evaluation.filtering import filter_negative_chunks


## 1. Configuration
We'll use a subset of SQuAD and the MNLI Judge for evaluation.

In [None]:

# Experiment settings. Only a handful of samples are used so the
# demonstration stays small.
config = dict(
    dataset="squad",
    num_samples=5,
    dv_methods=["LOO"],
    judge_type="mnli",
)

print(f"Running real-world experiment with config: {config}")


## 2. Data Loading & RAG Simulation
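We load SQuAD and simulate a RAG process: each question's ground-truth context is split into up to five sentence-level chunks (standing in for relevant retrievals), and a fixed distractor chunk is appended as an irrelevant one.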

In [None]:

def load_squad_samples(n=5, max_chunks=5):
    """Build simulated RAG samples from the SQuAD validation split.

    Each sample pairs a SQuAD question with its first reference answer and a
    list of ``Chunk`` objects: pieces of the ground-truth context (split on
    ``". "``) followed by one fixed distractor chunk.

    Args:
        n: Number of items to take from the streamed validation split.
        max_chunks: Maximum number of context fragments kept per sample.
            The distractor chunk is added on top of this limit.

    Returns:
        A list of dicts with the keys ``"query"``, ``"answer"`` and
        ``"chunks"``.
    """
    dataset = load_dataset("squad", split="validation", streaming=True)
    samples = []
    for item in dataset.take(n):
        # In a real RAG, these would come from a vector DB.
        # Here we use the ground truth context + a distractor.
        query = item["question"]
        answer = item["answers"]["text"][0]

        # Split the context on ". " and discard blank fragments (e.g. the
        # empty tail produced when the context ends with ". ") so that they
        # neither become empty chunks nor use up one of the max_chunks slots.
        stripped = (piece.strip() for piece in item["context"].split(". "))
        fragments = [piece for piece in stripped if piece]

        # NOTE(review): only the first max_chunks fragments are kept, so the
        # fragment containing the answer may be dropped for long contexts.
        chunks = [
            Chunk(id=f"c{i}", text=text)
            for i, text in enumerate(fragments[:max_chunks])
        ]

        # Add a distractor chunk
        chunks.append(Chunk(id="distractor", text="The moon is made of green cheese and cats like milk."))

        samples.append({
            "query": query,
            "answer": answer,
            "chunks": chunks,
        })
    return samples

# Pull the configured number of samples and report how many were built.
data = load_squad_samples(n=config["num_samples"])
print(f"Loaded {len(data)} samples.")


## 3. Experiment: Baseline vs DV-Filtered
- **Baseline**: All retrieved chunks are used as context.
- **DV-Filtered**: Chunks with a negative marginal contribution (LOO) are removed.

In [None]:

# Set up the judge, the valuation method(s) built on it, and the suite that
# runs them together.
judge = MNLIJudge()
valuators = {"LOO": LOOValuator(judge)}
suite = ValuationSuite(valuators)

results = []

for item in tqdm(data):
    query, answer, chunks = item["query"], item["answer"], item["chunks"]

    # Baseline: score the answer against the text of every chunk.
    baseline_context = " ".join(c.text for c in chunks)
    baseline_faithfulness = judge.get_faithfulness(query, baseline_context, answer)

    # Value the chunks, then let filter_negative_chunks decide which to keep.
    dv_results = suite.evaluate_all(query, chunks, answer)
    filtered_chunks = filter_negative_chunks(chunks, dv_results)

    # DV-filtered: score the same answer against the surviving chunks only.
    filtered_context = " ".join(c.text for c in filtered_chunks)
    filtered_faithfulness = judge.get_faithfulness(query, filtered_context, answer)

    # One row per sample; "improvement" is filtered minus baseline.
    record = {
        "query": query,
        "baseline_faith": baseline_faithfulness,
        "filtered_faith": filtered_faithfulness,
        "improvement": filtered_faithfulness - baseline_faithfulness,
        "num_removed": len(chunks) - len(filtered_chunks),
    }
    results.append(record)

exp_df = pd.DataFrame(results)


## 4. Results Analysis
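We compare the average faithfulness scores of the baseline and DV-filtered contexts, then plot the per-sample scores and the distribution of improvements.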

In [None]:

# Mean faithfulness over all samples, with and without DV filtering.
avg_baseline = exp_df["baseline_faith"].mean()
avg_filtered = exp_df["filtered_faith"].mean()

print(f"Average Baseline Faithfulness: {avg_baseline:.4f}")
print(f"Average Filtered Faithfulness: {avg_filtered:.4f}")
print(f"Average Improvement: {(avg_filtered - avg_baseline):.4f}")

# Per-sample bar chart: one pair of bars (baseline, filtered) per row.
score_columns = ["baseline_faith", "filtered_faith"]
bar_ax = exp_df[score_columns].plot(kind="bar", figsize=(12, 6))
bar_ax.set_title("Faithfulness Comparison: Baseline vs DV-Filtered")
bar_ax.set_xlabel("Sample Index")
bar_ax.set_ylabel("Faithfulness Score")
bar_ax.legend(["Baseline", "DV-Filtered"])
plt.show()

# Histogram of the per-sample improvement (filtered minus baseline).
hist_ax = exp_df["improvement"].hist(bins=10)
hist_ax.set_title("Distribution of Faithfulness Improvement")
hist_ax.set_xlabel("Improvement")
hist_ax.set_ylabel("Frequency")
plt.show()
