This project provides a framework for building and evaluating interpretable Retrieval-Augmented Generation (RAG) pipelines. The core of the project is the ability to explain the contribution of retrieved documents and query parts to the final generated output, using Shapley values.
- Interpretable RAG Pipeline: A RAG pipeline that can explain the importance of retrieved documents and query components.
- Shapley Value-based Explanations: Uses Shapley values to quantify the contribution of each input feature (e.g., retrieved documents, query tokens) to the generated output.
- Multiple Explanation Aggregations: Supports different aggregation methods for Shapley values, such as token-level, sequence-level, and bag-of-words.
- Online and Offline Retrieval: Supports both online and offline retrieval methods.
- Extensible and Modular: The code is organized in a modular way, making it easy to extend and adapt to different models and datasets.
This option is mainly useful for development or for replicating our experiments. First, clone this repository.
After cloning or downloading the repository, run the Linux shell script ./setup.sh.
It will initialize the workspace by performing the following steps:
- It will install the required Python modules by running
  pip install -r "./requirements.txt"
- It will download the necessary Python code to compute the BARTScore by Yuan et al. (2021) to "./resources/bart_score.py".
- It will download the necessary Python code to compute the LongDocFACTScore by Bishop et al. (2024) to "./resources/ldfacts.py".
The classes ExplainableAutoModelForRetrieval and ExplainableAutoModelForGeneration wrap Hugging Face transformers models and provide the functionality to compute attribution scores.
For retrieval, RAG-E supports the following token attribution methods:
- .grad(...): Raw gradients towards the inputs of the last batch.
- .aGrad(...): AGrad (da ⊙ a) by Liu et al. (2021) of the last batch.
- .gradIn(...): Gradient times input (dx ⊙ x) scores of the last batch.
- .intGrad(...): Integrated Gradients by Sundararajan et al. (2017) of the last batch.
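To make the difference between gradient times input and Integrated Gradients concrete, here is a minimal sketch on a toy differentiable function. It is not RAG-E's implementation: `toy_model`, `toy_grad`, `WEIGHTS`, and both helper functions are hypothetical names introduced only for illustration, and the gradient is written analytically instead of being computed by autograd.

```python
from typing import Callable, List

Vector = List[float]

def grad_times_input(grad_fn: Callable[[Vector], Vector], x: Vector) -> Vector:
    """Gradient times input (dx ⊙ x): element-wise product of gradient and input."""
    return [g * xi for g, xi in zip(grad_fn(x), x)]

def integrated_gradients(grad_fn: Callable[[Vector], Vector], x: Vector,
                         baseline: Vector, steps: int = 50) -> Vector:
    """Approximate Integrated Gradients with a midpoint Riemann sum."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th interval on [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            total[i] += g
    # Scale the averaged gradient by the distance from the baseline.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]

# Hypothetical toy model: f(x) = sum_i w_i * x_i^2, with gradient 2 * w_i * x_i.
WEIGHTS = [1.0, -2.0, 0.5]

def toy_model(x: Vector) -> float:
    return sum(w * xi * xi for w, xi in zip(WEIGHTS, x))

def toy_grad(x: Vector) -> Vector:
    return [2.0 * w * xi for w, xi in zip(WEIGHTS, x)]

x = [1.0, 2.0, -3.0]
baseline = [0.0, 0.0, 0.0]
gxi = grad_times_input(toy_grad, x)
ig = integrated_gradients(toy_grad, x, baseline)
```

In this sketch the Integrated Gradients attributions sum to `toy_model(x) - toy_model(baseline)` (the completeness property), whereas gradient times input gives no such guarantee for a nonlinear function.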
For generation, RAG-E supports exact and Kernel-SHAP-approximated Shapley values for document attribution.
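As an illustration of what an exact Shapley value over documents means, the following sketch enumerates every coalition of documents and averages each document's marginal contribution. It is a generic textbook formulation, not RAG-E's code: `exact_shapley` and `toy_value` are hypothetical names, and `toy_value` is only a stand-in for whatever score a generator would assign to its output given a subset of the retrieved documents.

```python
from itertools import combinations
from math import factorial
from typing import Callable, FrozenSet, List

def exact_shapley(n_docs: int,
                  value_fn: Callable[[FrozenSet[int]], float]) -> List[float]:
    """Exact Shapley values by enumerating all coalitions (cost grows as 2^n)."""
    phi = [0.0] * n_docs
    for i in range(n_docs):
        others = [j for j in range(n_docs) if j != i]
        for size in range(len(others) + 1):
            # Probability weight of a coalition of this size preceding document i.
            weight = (factorial(size) * factorial(n_docs - size - 1)
                      / factorial(n_docs))
            for subset in combinations(others, size):
                coalition = frozenset(subset)
                marginal = value_fn(coalition | {i}) - value_fn(coalition)
                phi[i] += weight * marginal
    return phi

def toy_value(coalition: FrozenSet[int]) -> float:
    """Hypothetical score: document 2 helps most, document 0 a little,
    and the two together add a small interaction bonus."""
    score = 0.0
    if 2 in coalition:
        score += 0.6
    if 0 in coalition:
        score += 0.1
    if 0 in coalition and 2 in coalition:
        score += 0.2
    return score

phi = exact_shapley(3, toy_value)
```

The attributions sum to the difference between the score with all documents and the score with none. Because the enumeration is exponential in the number of documents, sampling-based approximations such as Kernel-SHAP become attractive as the context grows.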
The ExplainableAutoModelForRAG class combines these models to create the full interpretable RAG pipeline.
Here is an example with Llama 3.1 8B and Snowflake Arctic Embed v2 (more examples are in scripts/demos):
import torch  # needed for torch.bfloat16 below

from src.Interpretable_RAG.rag import ExplainableAutoModelForRAG
# Load Pipeline:
model = ExplainableAutoModelForRAG(
    # Retriever info:
    query_encoder_name_or_path='Snowflake/snowflake-arctic-embed-l-v2.0',
    retriever_query_format='query: {query}',
    retriever_token_processor=lambda s: s.replace('▁', ' '),
    retriever_kwargs={'add_pooling_layer': False},
    # Generator info:
    generator_name_or_path='meta-llama/Llama-3.1-8B-Instruct',
    generator_token_processor=lambda s: s.replace('Ġ', ' ').strip('Ċ'),
    generator_kwargs={'device_map': 'auto', 'torch_dtype': torch.bfloat16}
)
# MSMarco query and passage as an example:
query = "Where was Marie Curie born?"
contexts = [
    "Maria Skłodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace.",
    "Maria Skłodowska was born in Warsaw, in Congress Poland in the Russian Empire, as the fifth and youngest child of well-known teachers Bronisława, née Boguska, and Władysław Skłodowski.",
    "While a French citizen, Marie Skłodowska Curie, who used both surnames, never lost her sense of Polish identity. She taught her daughters the Polish language and took them on visits to Poland.",
    "Marie Curie founded the Curie Institute in Paris in 1920, and the Curie Institute in Warsaw in 1932.",
]
# Generate answer:
output = model(
    query=query,
    contexts=contexts,
    k=5,
    generator_kwargs={
        'max_new_tokens': 256,
        'do_sample': False,
        'top_p': 1,
        'num_beams': 1,
        'batch_size': 64,
        'max_samples_query': 32,
        'max_samples_context': 32,
        'conditional': True
    }
)
# Explain retriever:
ret_attributions = model.retriever.intGrad()
# Explain generator:
gen_attributions = model.generator.get_shapley_values()

To visualize these explanations, RAG-E includes easy-to-use plotting functions:
from src.Interpretable_RAG.plotting import visualize_importance_retriever, visualize_importance_generator, plot_document_importance_rag, higlight_importance_rag
# Functions to normalize tokens:
retriever_token_processor = lambda s: s.replace('▁', ' ')
generator_token_processor=lambda s: s.replace('Ġ', ' ').strip('Ċ')
# Generate highlighted tokens for the retriever:
visualize_importance_retriever(model.retriever, method='intGrad', token_processor=retriever_token_processor, show=True)
# Generate highlighted tokens for the generator:
visualize_importance_generator(model.generator, aggregation='token', token_processor=generator_token_processor, show=True)
# Generate highlighted tokens for the rag pipeline:
higlight_importance_rag(
    model,
    retriever_method='intGrad',
    show=True,
    retriever_token_processor=retriever_token_processor,
    generator_token_processor=generator_token_processor
)
# Plot document importance for the rag pipeline:
plot_document_importance_rag(model, show=True)

Replicating our experiments requires installation Option 2. The main entry point for running experiments is the scripts/run_pipeline.py script. Here is an example of how to run it:
python scripts/run_pipeline.py \
    --topics_path /path/to/your/topics.tsv \
    --ranked_list_path /path/to/your/ranked_list.csv \
    --collection_path /path/to/your/collection.tsv \
    --output_path /path/to/your/output_directory/ \
    --model_id meta-llama/Llama-3.1-8B-Instruct \
    --num_docs_context 6 \
    --max_gen_len 300 \
    --run_original \
    --run_randomized

Arguments:
- --topics_path: Path to the file with the queries.
- --ranked_list_path: Path to the ranked list from the retrieval.
- --collection_path: Path to the collection of passages.
- --output_path: Base directory to save all the results.
- --model_id: ID of the generative model from Hugging Face.
- --num_docs_context: Number of documents to use as context.
- --max_gen_len: Maximum length of the generated response.
- --run_original: Run the experiment with the original contexts.
- --run_randomized: Run with contexts in random order.
- --run_no_duplicates: Run with contexts without duplicates.
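For readers who want to wrap or extend the script, the options above could be declared with argparse roughly as follows. This is a hypothetical sketch and not the actual parser in scripts/run_pipeline.py: the types, the `required` settings, and the treatment of the three `--run_*` options as boolean switches are assumptions inferred from the example command.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented command-line options."""
    parser = argparse.ArgumentParser(
        description="Sketch of the run_pipeline.py options (assumed types).")
    parser.add_argument("--topics_path", required=True,
                        help="Path to the file with the queries.")
    parser.add_argument("--ranked_list_path", required=True,
                        help="Path to the ranked list from the retrieval.")
    parser.add_argument("--collection_path", required=True,
                        help="Path to the collection of passages.")
    parser.add_argument("--output_path", required=True,
                        help="Base directory to save all the results.")
    parser.add_argument("--model_id", required=True,
                        help="ID of the generative model from Hugging Face.")
    parser.add_argument("--num_docs_context", type=int, required=True,
                        help="Number of documents to use as context.")
    parser.add_argument("--max_gen_len", type=int, required=True,
                        help="Maximum length of the generated response.")
    # Assumed to be on/off switches, since the example passes them without values.
    for flag in ("--run_original", "--run_randomized", "--run_no_duplicates"):
        parser.add_argument(flag, action="store_true")
    return parser

# Parse the same options as the example command above.
example_args = build_parser().parse_args([
    "--topics_path", "/path/to/your/topics.tsv",
    "--ranked_list_path", "/path/to/your/ranked_list.csv",
    "--collection_path", "/path/to/your/collection.tsv",
    "--output_path", "/path/to/your/output_directory/",
    "--model_id", "meta-llama/Llama-3.1-8B-Instruct",
    "--num_docs_context", "6",
    "--max_gen_len", "300",
    "--run_original",
    "--run_randomized",
])
```

Under these assumptions, any `--run_*` switch that is omitted (here `--run_no_duplicates`) simply parses as `False`.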
- src/Interpretable_RAG: Contains the core source code for the interpretable RAG pipeline.
- scripts: Contains scripts for running experiments, building indexes, and performing analysis.
- data: Contains the data used for the experiments.
- resources: Contains additional resources, such as evaluation scripts.
- outputs_evaluation: Contains the evaluation outputs.
- outputs_retrieved: Contains the retrieved outputs.
- results: Contains the results of the experiments.
- plots: Contains plots for the analysis of the results.
Contributions are welcome! Please feel free to submit a pull request or open an issue.
This project is licensed under the GPL v3 License. See the LICENSE file for details.