<a href="https://www.kaggle.com/code/emmermarcell/rag-pipeline-on-the-wikipedia-dataset?scriptVersionId=159940868" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Implementing a RAG pipeline on the Wikipedia dataset



The Wikipedia articles are embedded with the [`all-MiniLM-L6-v2`][4] model from the [Sentence-Transformers][5] library. Each string is mapped to a $384$-dimensional vector, and the `faiss.IndexFlatL2` index performs an exact (brute-force) similarity search among these vectors based on their Euclidean (L2) distance.
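Conceptually, a flat L2 index stores every vector as-is and answers a query by comparing it with all of them. The pure-Python sketch below mirrors that idea. The class name `FlatL2Sketch` and the toy 2-dimensional vectors are illustrative assumptions, and faiss performs the same exhaustive comparison far faster. Like `IndexFlatL2`, the sketch reports squared L2 distances, which rank neighbours exactly as plain L2 distances would.

```python
import heapq
from typing import List, Sequence, Tuple


class FlatL2Sketch:
    """Hypothetical exact nearest-neighbour index using brute-force squared L2 distance."""

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self.vectors: List[Tuple[float, ...]] = []

    def add(self, batch: Sequence[Sequence[float]]) -> None:
        # Append vectors; the position of a vector in the list is its integer id
        for vec in batch:
            if len(vec) != self.dim:
                raise ValueError(f"expected dimension {self.dim}, got {len(vec)}")
            self.vectors.append(tuple(float(x) for x in vec))

    def search(self, query: Sequence[float], k: int) -> Tuple[List[float], List[int]]:
        # Squared Euclidean distance from the query to every stored vector
        dists = (
            (sum((q - v) ** 2 for q, v in zip(query, vec)), idx)
            for idx, vec in enumerate(self.vectors)
        )
        # Keep the k smallest distances (ties broken by the smaller id)
        best = heapq.nsmallest(k, dists)
        return [d for d, _ in best], [i for _, i in best]


index = FlatL2Sketch(dim=2)
index.add([[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]])
distances, ids = index.search([0.9, 1.2], k=2)
print(ids)  # [1, 0]
```

Because every query is compared with every stored vector, the search is exact but its cost grows linearly with the corpus size, which is why running it on GPUs pays off for millions of chunks.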

The notebook runs on the two T4 GPUs that Kaggle provides.
A good resource on running faiss on multiple GPUs is the [faiss GitHub tutorial][2]. For computing embeddings on multiple GPUs, I followed the [Sentence-Transformers GitHub example][3].
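Multi-GPU embedding in Sentence-Transformers works, roughly, by splitting the input into chunks, handing each chunk to a worker, and reassembling the results in the original order. The sketch below illustrates only that chunk-and-reorder logic with the standard library's `concurrent.futures`. The names `toy_embed` and `encode_in_chunks` are hypothetical stand-ins, not the library's implementation, and worker threads stand in for the per-GPU worker processes.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, List, Sequence


def toy_embed(texts: Sequence[str]) -> List[List[float]]:
    # Hypothetical stand-in for a real sentence encoder: a 2-d "embedding" per text
    return [[float(len(t)), float(t.count(" "))] for t in texts]


def encode_in_chunks(
    texts: Sequence[str],
    embed: Callable[[Sequence[str]], List[List[float]]],
    num_workers: int = 2,
    chunk_size: int = 3,
) -> List[List[float]]:
    # Split the input into (chunk_id, chunk) pairs
    chunks = [(cid, texts[start:start + chunk_size])
              for cid, start in enumerate(range(0, len(texts), chunk_size))]
    results = {}
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = {pool.submit(embed, chunk): cid for cid, chunk in chunks}
        # Chunks may finish in any order ...
        for fut in as_completed(futures):
            results[futures[fut]] = fut.result()
    # ... so reassemble them by chunk id to preserve the input order
    return [vec for cid in sorted(results) for vec in results[cid]]


texts = ["a", "b c", "d e f", "gh", "ij kl", "m", "no p"]
embeddings = encode_in_chunks(texts, toy_embed)
print(len(embeddings))  # 7
print(embeddings[2])    # [5.0, 2.0]
```

Preserving the order matters here: the i-th vector added to the faiss index must correspond to the i-th text, otherwise the returned neighbour ids would point at the wrong chunks.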

After embedding and ranking the article chunks, I apply the [`distilbert-base-cased-distilled-squad`][9] Q&A pipeline. This model is a fine-tuned version of [`DistilBERT-base-cased`][10], trained with (a second step of) knowledge distillation on the [`SQuAD v1.1`][11] dataset.
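In knowledge distillation, a small student model is trained to match the softened output distribution of a larger teacher. The sketch below computes the generic soft-target loss (temperature-scaled softmax, cross-entropy against the teacher's distribution, scaled by T²) on toy logits. It illustrates the technique in general and is an assumption, not the exact recipe used to train `distilbert-base-cased-distilled-squad`.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    # Temperature-scaled softmax; subtract the max for numerical stability
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits: Sequence[float], teacher_logits: Sequence[float],
                      temperature: float = 2.0) -> float:
    # Soft targets from the teacher, soft predictions from the student
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # Cross-entropy between the teacher and student distributions
    ce = -sum(pt * math.log(ps) for pt, ps in zip(p_teacher, p_student))
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures
    return ce * temperature ** 2


teacher = [4.0, 1.0, 0.5]
# A student that matches the teacher gets a lower loss than one that does not
print(distillation_loss([4.0, 1.0, 0.5], teacher) < distillation_loss([0.5, 1.0, 4.0], teacher))  # True
```

A temperature above 1 flattens both distributions, so the student also learns how the teacher ranks the wrong answers rather than only its top choice.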

I used the following articles as a starting point for implementing a RAG pipeline:

* [Akriti Upadhyay - Implementing RAG with Langchain and Hugging Face][6]

* [Vladimir Blagojevic - Ask Wikipedia ELI5-like Questions Using Long-Form Question Answering on Haystack][7]

* [Steven van de Graaf - Pre-processing a Wikipedia dump for NLP model training — a write-up][12]

[1]: https://huggingface.co/learn/nlp-course/chapter5/4?fw=pt
[2]: https://github.com/facebookresearch/faiss/blob/main/tutorial/python/5-Multiple-GPUs.py
[3]: https://github.com/UKPLab/sentence-transformers/blob/master/examples/applications/computing-embeddings/computing_embeddings_multi_gpu.py
[4]: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
[5]: https://www.sbert.net/
[6]: https://medium.com/international-school-of-ai-data-science/implementing-rag-with-langchain-and-hugging-face-28e3ea66c5f7
[7]: https://medium.com/international-school-of-ai-data-science/implementing-rag-with-langchain-and-hugging-face-28e3ea66c5f7
[8]: https://huggingface.co/datasets/wikipedia
[9]: https://huggingface.co/distilbert-base-cased-distilled-squad
[10]: https://huggingface.co/distilbert-base-cased
[11]: https://huggingface.co/datasets/squad
[12]: https://towardsdatascience.com/pre-processing-a-wikipedia-dump-for-nlp-model-training-a-write-up-3b9176fdf67

In [None]:
!pip install sentence_transformers
!pip install faiss-gpu

In [None]:
import gc    # Garbage collector
import logging
from tqdm.auto import tqdm

import numpy as np
from datasets import load_dataset
from transformers import pipeline
from sentence_transformers import SentenceTransformer, LoggingHandler
import faiss


logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)

# Check how many GPUs faiss can use
ngpus = faiss.get_num_gpus()
print("number of GPUs:", ngpus)

In [None]:
# Important: guard the code with `if __name__ == "__main__":`. Otherwise, CUDA runs into issues when the multi-process pool spawns new processes.
if __name__ == "__main__":
    wiki_corpus_path = '/kaggle/input/wikipedia-sentences/wikisent2.txt'
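    # Load the preprocessed corpus (one text per line) into a Hugging Face dataset
    processed_wiki_dataset = load_dataset("text", data_files={"train": wiki_corpus_path})
    # len() of the DatasetDict itself would only count the splits, so count the rows of 'train'
    print(f'Length of the Wikipedia dataset is {len(processed_wiki_dataset["train"])} article chunks.')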
    
    # Define the sentence transformer model
    model_name = "all-MiniLM-L6-v2"
    model = SentenceTransformer(model_name)
    embedding_dim = model.get_sentence_embedding_dimension()    # Embedding dimension
    max_seq_len = model.max_seq_length    # Maximum input length in tokens; longer inputs are truncated
    print(f'The embedding dimension of the {model_name} model is {embedding_dim}.')
    print(f'Max sequence length of the {model_name} model is {max_seq_len} tokens.')
    
    # Initialize a FAISS index (for CPU)
    cpu_index = faiss.IndexFlatL2(embedding_dim)

    # Move the index to all available GPUs (by default faiss replicates it on each GPU)
    gpu_index = faiss.index_cpu_to_all_gpus(cpu_index)

    # Start the multi-process pool on all available CUDA devices
    pool = model.start_multi_process_pool()
    
    # Batch processing with a tqdm progress bar
    # Choose the batch size according to the available memory
    batch_size = 2**13
    num_texts = len(processed_wiki_dataset['train'])
    total_batches = (num_texts + batch_size - 1) // batch_size    # ceiling division
    
    for i in tqdm(range(0, num_texts, batch_size), total=total_batches, desc="Processing Batches"):
        # Take the next batch of article chunks (slicing the Dataset loads only these rows,
        # whereas ['train']['text'] would materialize the whole column on every iteration)
        batch_texts = processed_wiki_dataset['train'][i:i + batch_size]['text']
        # Compute the embeddings using the multi-process pool
        batch_embeddings = model.encode_multi_process(batch_texts, pool)
        # Add the embeddings to the GPU index
        gpu_index.add(batch_embeddings)
        
        # Memory management
        del batch_embeddings
        gc.collect()
        
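    # Function to search for relevant article chunks using the GPU index
    def search_wiki_articles(question, k=5):
        # Embed the question as a one-element list; encode_multi_process expects a list
        # of texts and would iterate over a bare string character by character
        question_embedding = model.encode([question])
        distances, indices = gpu_index.search(question_embedding, k)
        # Row access on the Dataset avoids materializing the whole 'text' column
        return [processed_wiki_dataset['train'][int(i)]['text'] for i in indices[0]]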

    # Example business questions
    questions = [
        'What services does KPMG offer to its clients?',
        'What are the key considerations when assessing internal controls during an audit?',
        'How do you stay updated on changes in tax laws and regulations affecting clients?',
        "What steps do you take to understand a client's business before initiating a consulting project?",
        'What due diligence processes are crucial for evaluating the financial health of a potential acquisition?',
    ]
    
    relevant_article_chunks = [search_wiki_articles(question) for question in questions]
    
    # Example
    print(f'Question:\n{questions[0]}')
    print(f'Relevant Wikipedia article chunks:\n{relevant_article_chunks[0]}')
    
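    # Stop the worker processes of the pool once all embeddings are computed
    model.stop_multi_process_pool(pool)
    
    # (Optional) Save the faiss index; a GPU index has to be copied back to the CPU first
    # faiss.write_index(faiss.index_gpu_to_cpu(gpu_index), 'Wikipedia_FlatL2.index')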

The `distilbert-base-cased-distilled-squad` tokenizer accepts at most 512 tokens before truncation. Each retrieved chunk is a single line of the corpus file, so a question together with its k=5 nearest-neighbour chunks from the faiss index fits comfortably within this limit.
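This budget argument can be checked with a minimal sketch that greedily packs the retrieved chunks into the model's token limit. It rests on stated assumptions: whitespace-separated words stand in for WordPiece tokens (the real tokenizer usually produces more tokens than words), three positions are reserved for the `[CLS]`/`[SEP]` special tokens, and `pack_context` is a hypothetical helper that the notebook does not use.

```python
from typing import List, Sequence


def count_tokens(text: str) -> int:
    # Crude stand-in for a tokenizer: count whitespace-separated words
    return len(text.split())


def pack_context(question: str, chunks: Sequence[str], max_tokens: int = 512,
                 special_tokens: int = 3) -> List[str]:
    # Tokens left for the context after the question and the [CLS]/[SEP] markers
    budget = max_tokens - special_tokens - count_tokens(question)
    packed = []
    for chunk in chunks:  # chunks are assumed to be ordered by relevance (nearest first)
        n = count_tokens(chunk)
        if n > budget:
            break  # stop at the first chunk that no longer fits
        packed.append(chunk)
        budget -= n
    return packed


chunks = ["one two three"] * 5
print(len(pack_context("what is it", chunks)))                 # 5
print(len(pack_context("what is it", chunks, max_tokens=12)))  # 2
```

Joining the packed chunks with `' '.join(...)` gives a context of the same form that `answer_question` builds below.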

In [None]:
# Load the Q&A pipeline
qa_model_name = 'distilbert-base-cased-distilled-squad'
question_answerer = pipeline('question-answering', model=qa_model_name, tokenizer=qa_model_name)
print(f'The maximum number of tokens we can feed to the tokenizer before truncation is {question_answerer.tokenizer.model_max_length}.')

# Function to perform question-answering given a question and a list of documents
def answer_question(question, article_chunks):
    # Combine the article chunks into a single string
    context = ' '.join(article_chunks)

    # Perform question-answering
    result = question_answerer(question=question, context=context)

    return result

# Answer the questions
results = [answer_question(questions[idx], relevant_article_chunks[idx]) for idx in range(len(questions))]

for idx in range(len(questions)):
    print(f'Question:\n{questions[idx]}')
    print(f"Answer:\n'{results[idx]['answer']}',\nscore:{round(results[idx]['score'], 4)}, start:{results[idx]['start']}, end:{results[idx]['end']}")
    print('='*30)
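Under the hood, an extractive Q&A model outputs a start score and an end score for every context token, and the answer is the span with the highest combined score, subject to start ≤ end and a maximum answer length. Maximizing the sum of the start and end logits is equivalent to maximizing the product of their softmax probabilities. The simplified sketch below works on made-up scores; `best_span` and its `max_answer_len` default are illustrative assumptions. The real pipeline additionally masks the question tokens, turns the scores into the reported `score`, and maps token positions back to the character offsets returned as `start` and `end`.

```python
from typing import Sequence, Tuple


def best_span(start_scores: Sequence[float], end_scores: Sequence[float],
              max_answer_len: int = 15) -> Tuple[int, int, float]:
    # Search all (start, end) pairs with start <= end < start + max_answer_len
    best = (0, 0, float("-inf"))
    for s, s_score in enumerate(start_scores):
        for e in range(s, min(s + max_answer_len, len(end_scores))):
            score = s_score + end_scores[e]
            if score > best[2]:
                best = (s, e, score)
    return best


tokens = ["Denver", "Broncos", "won", "Super", "Bowl", "50"]
start = [3.0, 0.5, 0.1, 0.2, 0.0, 0.1]
end = [0.4, 2.5, 0.3, 0.1, 0.2, 1.0]
s, e, score = best_span(start, end)
print(" ".join(tokens[s:e + 1]))  # Denver Broncos
```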

## Evaluation of the pipeline

To evaluate a Q&A pipeline, one can use the [Official Evaluation Script][1] of the [SQuAD v2.0][2] dataset. Below I include the part of the script that computes the exact-match (EM) and F1 scores of a prediction; these functions also apply to the SQuAD v1.1 data loaded afterwards.

[1]: https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
[2]: https://rajpurkar.github.io/SQuAD-explorer/
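Both metrics are computed after normalization (lowercasing, removing punctuation and the articles a/an/the, collapsing whitespace). EM is 1 only if the normalized strings are identical, while F1 compares the bags of tokens, with $\text{num\_same}$ the size of their multiset intersection. As a worked example with a made-up prediction, take the gold answer "Denver Broncos" (tokens: denver, broncos) and the prediction "the Denver Broncos team" (tokens after normalization: denver, broncos, team):

$$
\text{EM} = \mathbb{1}\left[\operatorname{norm}(a_{\text{gold}}) = \operatorname{norm}(a_{\text{pred}})\right] = 0
$$

$$
P = \frac{\text{num\_same}}{|\text{pred}|} = \frac{2}{3}, \qquad
R = \frac{\text{num\_same}}{|\text{gold}|} = \frac{2}{2} = 1, \qquad
F_1 = \frac{2PR}{P+R} = \frac{2 \cdot \frac{2}{3} \cdot 1}{\frac{2}{3} + 1} = \frac{4/3}{5/3} = 0.8
$$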

In [None]:
"""Official evaluation script for SQuAD version 2.0.

In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
import argparse
import collections
import json
import numpy as np
import os
import re
import string
import sys

OPTS = None

def parse_args():
    parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.')
    parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.')
    parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.')
    parser.add_argument('--out-file', '-o', metavar='eval.json',
                        help='Write accuracy metrics to file (default is stdout).')
    parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json',
                        help='Model estimates of probability of no answer.')
    parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0,
                        help='Predict "" if no-answer probability exceeds this (default = 1.0).')
    parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None,
                        help='Save precision-recall curves to directory.')
    parser.add_argument('--verbose', '-v', action='store_true')
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()

def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))

def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

In [None]:
# Load the SQuAD validation split (the Q&A model was fine-tuned on the train split)
squad_dataset = load_dataset('squad', split='validation')

# Example use
question = squad_dataset['question'][0]
context = squad_dataset['context'][0]
answer = squad_dataset['answers'][0]

# Predict the answer using the DistilBERT Q&A pipeline
result = question_answerer(question=question, context=context)

# Compute the evaluation scores against the first gold answer
exact_score = compute_exact(a_gold=answer['text'][0], a_pred=result['answer'])
f1_score = compute_f1(a_gold=answer['text'][0], a_pred=result['answer'])


print(f'Question:\n{question}')
print(f"Predicted answer:\n{result['answer']}")
print(f"Correct answer according to SQuAD:\n{answer['text'][0]}")
print(f'F1: {f1_score},\tEM: {exact_score}')
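The cell above scores one prediction against the first gold answer only. SQuAD provides several gold answers per question, and the official script takes, for each question, the maximum score over all gold answers and then averages over questions, reported as a percentage. The self-contained sketch below reproduces that aggregation; it re-implements the normalization compactly so it runs on its own, and the toy `examples` list is made up. With the notebook's objects, `examples` could be built from `squad_dataset['answers']` and the pipeline's predictions.

```python
import collections
import re
import string
from typing import Dict, Sequence


def normalize_answer(s: str) -> str:
    # Lowercase, drop punctuation and articles, collapse whitespace (same order as the SQuAD script)
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def compute_exact(a_gold: str, a_pred: str) -> int:
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_f1(a_gold: str, a_pred: str) -> float:
    gold_toks = normalize_answer(a_gold).split()
    pred_toks = normalize_answer(a_pred).split()
    if not gold_toks or not pred_toks:
        # If either is no-answer, F1 is 1 if they agree, 0 otherwise
        return float(gold_toks == pred_toks)
    num_same = sum((collections.Counter(gold_toks) & collections.Counter(pred_toks)).values())
    if num_same == 0:
        return 0.0
    precision, recall = num_same / len(pred_toks), num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)


def evaluate(examples: Sequence[Dict]) -> Dict[str, float]:
    # Best score over all gold answers per question, then the mean over questions (in %)
    exact = [max(compute_exact(g, ex["prediction"]) for g in ex["gold"]) for ex in examples]
    f1 = [max(compute_f1(g, ex["prediction"]) for g in ex["gold"]) for ex in examples]
    return {"exact": 100.0 * sum(exact) / len(exact), "f1": 100.0 * sum(f1) / len(f1)}


examples = [
    {"gold": ["Denver Broncos", "The Denver Broncos"], "prediction": "the Denver Broncos team"},
    {"gold": ["Santa Clara, California"], "prediction": "Santa Clara, California"},
]
print(evaluate(examples))  # exact 50.0, f1 about 90.0
```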

### Making the code accessible for end-users

The pipeline can be made accessible to end users, e.g. through a simple GUI or a small web application. Some references on how to do this:

* [Kamalraj M M - Step By Step Guide to Integrate LLM with GUI: Improving Performance Of GUI with LLM][1]
* [PySimpleGUI][2]
* Making a simple website from scratch using Flask or Django

[1]: https://www.youtube.com/watch?v=nWi8yM4bCmM
[2]: https://www.pysimplegui.org/en/latest/#jump-start
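For the web-application route, a minimal sketch using only the standard library's `wsgiref` (instead of Flask or Django, to keep it self-contained) could expose the pipeline as a JSON endpoint. `make_app`, `dummy_answer` and `serve` are hypothetical names, and `dummy_answer` is a placeholder that in the notebook would wrap `search_wiki_articles` and `answer_question`.

```python
import json
from typing import Callable, Dict
from urllib.parse import parse_qs
from wsgiref.simple_server import make_server


def make_app(answer_fn: Callable[[str], Dict]):
    """Build a WSGI app that answers GET /?q=<question> with a JSON body."""
    def app(environ, start_response):
        params = parse_qs(environ.get("QUERY_STRING", ""))
        question = params.get("q", [""])[0].strip()
        if not question:
            status, body = "400 Bad Request", {"error": "missing query parameter 'q'"}
        else:
            status, body = "200 OK", answer_fn(question)
        payload = json.dumps(body).encode("utf-8")
        start_response(status, [("Content-Type", "application/json"),
                                ("Content-Length", str(len(payload)))])
        return [payload]
    return app


def dummy_answer(question: str) -> Dict:
    # Hypothetical placeholder: in the notebook this would call
    # search_wiki_articles(question) and then answer_question(...)
    return {"question": question, "answer": "(answer goes here)", "score": 0.0}


def serve(app, port: int = 8000) -> None:
    # Blocks until interrupted; then open http://localhost:<port>/?q=... in a browser
    with make_server("", port, app) as server:
        server.serve_forever()


# Exercise the app directly, without starting a server
app = make_app(dummy_answer)
captured = {}


def start_response(status, headers):
    captured["status"] = status


body = b"".join(app({"QUERY_STRING": "q=What+is+KPMG%3F"}, start_response))
print(captured["status"], json.loads(body)["question"])  # 200 OK What is KPMG?
```

Calling `serve(make_app(...))` starts a blocking development server; a framework such as Flask adds routing, templates and a production-ready deployment story on top of the same request/response idea.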