# Intro
This notebook builds a simple dataset which will be a base for our RAG system.
The dataset will be composed of random Wikipedia pages. It makes a good corpus to practice building
RAG systems for a common use case: question answering over internal documentation. There are of course a few differences, including
scale (depending on the company) and the absence of the specialised jargon and concepts that internal documentation often contains and that fall outside
LLMs' training distribution (Wikipedia, by contrast, is part of the pre-training corpus of many LLMs).

Some of the chunking methods implemented here can also be found in popular libraries (e.g. LangChain); we rewrite them for fun.

In [None]:
import os
import re
from typing import Callable, Iterable, Dict, Any
from tqdm import tqdm

import numpy as np
from datasets import Dataset, load_dataset, load_from_disk
from huggingface_hub import InferenceClient

LOCAL_DATASET_FOLDER = "local_datasets"

In [None]:
# Load a HF dataset of already-parsed Wikipedia pages:
# this saves us the time of fetching and parsing pages ourselves
wiki_data = load_dataset("wikipedia", "20220301.simple")["train"]

In [None]:
wiki_data

In [None]:
# Have a look at one element
first_el = wiki_data[0].copy() # for safe edits
first_el["text"] = first_el["text"][:200] + "..." # for easier readability
first_el

We are pretty close to the dataset format we need for indexing. The main blocker is that the text field is too long for the limited context size of some embedding models we'd like to use (e.g. BERT accepts at most 512 tokens, and that is fewer than 512 words because many words are split into several sub-word tokens). We could also use
larger models with larger context sizes, but research also suggests that models tend to lose track of information placed in the middle of long contexts: https://huggingface.co/papers/2307.03172.

As a result, it seems preferable to keep a relatively small context size, which means we'll need to chunk our texts. The rest of the notebook explores different ways to do it.
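As a rough sanity check on the chunk size, the sketch below estimates how many words fit in a 512-token window. The ratio of 1.3 tokens per word is an assumed rule of thumb for English sub-word tokenizers, not a value measured on this corpus. `max_words_for_context` is a hypothetical helper.

```python
# Assumed rule of thumb: English text averages ~1.3 sub-word tokens per word.
# The real ratio depends on the tokenizer and the text; measure it on your corpus if it matters.
ASSUMED_TOKENS_PER_WORD = 1.3


def max_words_for_context(
    context_tokens: int,
    reserved_tokens: int = 2,  # e.g. BERT's [CLS] and [SEP] around a single sequence
    tokens_per_word: float = ASSUMED_TOKENS_PER_WORD,
) -> int:
    """Estimate how many words fit in a context window (hypothetical helper)."""
    return int((context_tokens - reserved_tokens) / tokens_per_word)


print(max_words_for_context(512))  # 510 / 1.3 ≈ 392.3 -> 392
```

Under that assumption, the 300-word chunks used below leave some margin. Long or rare words can still push an individual chunk past the limit.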

In [None]:
# Define utils to apply a chunking method to the dataset per batch
# We'll define different chunking methods to use after this
def get_chunk_from_batch(
    examples_batch: Dict[str, list],
    chunk_text_method: Callable[..., Iterable[str]],
    **chunk_text_method_kwargs: Any
) -> Dict[str, list]:
    """
    Apply `chunk_text_method` to every text of `examples_batch` and return
    a dictionary in the format expected by `Dataset.map(batched=True)`
    (Dict[str, list of feature values]). Each input row can expand into several chunks.
    """
    chunk_ids = []
    example_ids = []
    example_urls = []
    example_titles = []
    chunks = []
    for ind, example_text in enumerate(examples_batch["text"]):
        chunk_iter = chunk_text_method(example_text, **chunk_text_method_kwargs)
        for chunk_ind, chunk in enumerate(chunk_iter):
            # "<article id>_<chunk index>" stays unique across batches,
            # unlike a counter that would restart at 0 for every batch
            chunk_ids.append(f"{examples_batch['id'][ind]}_{chunk_ind}")
            example_ids.append(examples_batch["id"][ind])
            example_titles.append(examples_batch["title"][ind])
            example_urls.append(examples_batch["url"][ind])
            chunks.append(chunk)
    return {
        "id": chunk_ids,
        "original_id": example_ids,
        "title": example_titles,
        "url": example_urls,
        "text_chunk": chunks
    }

# Chunking strategies

## (Dummy) Fixed-length chunking (with some overlap)


In [None]:
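CHUNK_SIZE_WORDS = 300
OVERLAP_SIZE_WORDS = 10

# Alternative to using LangChain methods
# We do it independently of any tokeniser to keep it generic (using words as the unit), at the risk
# of exceeding the model's context size later on if a chunk turns out to contain too many tokens
def chunk_text_with_fixed_length(
    text: str,
    chunk_size_words: int = CHUNK_SIZE_WORDS,
    overlap_size_words: int = OVERLAP_SIZE_WORDS
) -> Iterable[str]:
    if not 0 <= overlap_size_words < chunk_size_words:
        # Otherwise the window would never move forward (infinite loop)
        raise ValueError("overlap_size_words must be in [0, chunk_size_words)")
    # split() without arguments splits on any whitespace (spaces, new lines)
    # and drops empty strings, so consecutive separators aren't counted as words
    words = text.split()

    total_words = len(words)
    # Slide a window of chunk_size_words over the words, moving by (chunk size - overlap)
    word_index = 0
    while word_index < total_words:
        yield " ".join(words[word_index:word_index + chunk_size_words])
        if word_index + chunk_size_words >= total_words:
            break  # this chunk already reaches the end of the text
        word_index += chunk_size_words - overlap_size_words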
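To show what the overlap does, the standalone sketch below lists the `[start, end)` word positions of each chunk for a 700-word text, using 300-word chunks and a 10-word overlap. `window_bounds` is a hypothetical helper that follows a simplified version of the stepping logic above. Each window starts `chunk_size - overlap = 290` words after the previous one.

```python
from typing import List, Tuple


def window_bounds(
    total_words: int, chunk_size_words: int = 300, overlap_size_words: int = 10
) -> List[Tuple[int, int]]:
    """Return the [start, end) word positions of each fixed-length chunk (hypothetical helper)."""
    step = chunk_size_words - overlap_size_words  # how far each window moves forward
    bounds = []
    start = 0
    while start < total_words:
        end = min(start + chunk_size_words, total_words)
        bounds.append((start, end))
        if end == total_words:  # this window already reaches the end of the text
            break
        start += step
    return bounds


print(window_bounds(700))  # [(0, 300), (290, 590), (580, 700)]
```

Consecutive windows share 10 words (positions 290-299 and 580-589), so a sentence cut at a chunk boundary still appears whole in at least one chunk, provided it is no longer than the overlap.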

In [None]:
wiki_data_chunked = wiki_data.map(
    lambda example_batch: get_chunk_from_batch(example_batch, chunk_text_with_fixed_length),
    batched=True,
    remove_columns=["id", "title", "text", "url"] # Removes columns because of row expansion
)

In [None]:
wiki_data_chunked[0:2]

In [None]:
# Save data back to disk
FIXED_LENGTH_CHUNK_DATASET_NAME = f"wiki-data-chunked-fixed-length-CS{CHUNK_SIZE_WORDS}-OS{OVERLAP_SIZE_WORDS}"
wiki_data_chunked.save_to_disk(
    os.path.join(LOCAL_DATASET_FOLDER, FIXED_LENGTH_CHUNK_DATASET_NAME)
)


## Paragraph recursive chunking

In [None]:
# Parse the main sections and try to use those as chunks.
# Sections that are too long are split into sub-sections, and the same logic is applied recursively.
# How sections are parsed depends on the document format: for markdown, we'd split on '#', then '##', etc.
# In this corpus, paragraphs and sections are both separated by '\n\n', and section titles are hard to tell apart
# from paragraphs other than by their size.
# We simply assume that '\n\n' marks a reasonably good semantic break, and recursively use it to split
# sections that are too long roughly in half.

# Ideally, individual paragraphs are all smaller than the maximum chunk size;
# when one isn't, we fall back to fixed-length chunking.

# Idea (not implemented here): prepend every section with its title and subtitle.
# In short: cut on titles, look at section sizes, and cut again if too long.
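The "split in half" heuristic only needs the separator occurrence that sits closest to the middle of the text. Below is a minimal standalone sketch on a toy string. `best_split_index` is a hypothetical helper, and it measures the middle in characters, as the function below does.

```python
import re


def best_split_index(text: str, separator: str = "\n\n") -> int:
    """Return the start index of the separator occurrence closest to the middle of `text`."""
    candidates = [m.start() for m in re.finditer(re.escape(separator), text)]
    if not candidates:
        raise ValueError("The text contains no separator to split on")
    middle = len(text) // 2  # measured in characters
    return min(candidates, key=lambda ind: abs(ind - middle))


toy_text = "Intro.\n\nPart one is long.\n\nPart two.\n\nEnd."
split_ind = best_split_index(toy_text)
print(repr(toy_text[:split_ind]))      # 'Intro.\n\nPart one is long.'
print(repr(toy_text[split_ind + 2:]))  # 'Part two.\n\nEnd.'
```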

In [None]:
def chunk_text_recursively_per_section(text: str, max_chunk_size_words: int = CHUNK_SIZE_WORDS) -> list[str]:
    SPLIT_STR = "\n\n"
    if not text.strip():
        # Nothing worth indexing (e.g. an empty side of a split)
        return []
    if len(text.split()) <= max_chunk_size_words:
        return [text]
    if SPLIT_STR not in text:
        # We can't split the text further and it's too big: fall back to the dummy chunking strategy
        return list(chunk_text_with_fixed_length(text, max_chunk_size_words, OVERLAP_SIZE_WORDS))

    # There's at least one split candidate in the text, pick the best one
    # (= the one closest to the middle of the text, measured in characters)
    half_text_len = len(text) // 2
    all_potential_splits = [m.start() for m in re.finditer(re.escape(SPLIT_STR), text)]
    best_split_ind = min(all_potential_splits, key=lambda split_ind: abs(split_ind - half_text_len))

    left_text = text[:best_split_ind]
    right_text = text[best_split_ind + len(SPLIT_STR):]
    return (
        chunk_text_recursively_per_section(left_text, max_chunk_size_words)
        + chunk_text_recursively_per_section(right_text, max_chunk_size_words)
    )

In [None]:
wiki_data_chunked_recursive = wiki_data.map(
    lambda example_batch: get_chunk_from_batch(example_batch, chunk_text_recursively_per_section),
    batched=True,
    remove_columns=["id", "title", "text", "url"] # Removes columns because of row expansion
)

In [None]:
wiki_data_chunked_recursive[0:2]

In [None]:
# Save data back to disk
RECURSIVE_DATASET_NAME = f"wiki-data-chunked-recursive-CS{CHUNK_SIZE_WORDS}"
wiki_data_chunked_recursive.save_to_disk(
    os.path.join(LOCAL_DATASET_FOLDER, RECURSIVE_DATASET_NAME)
)

## [Optional - can be skipped] Modelling approach

If the previous method yields disappointing results, which could happen if the sections are harder to segment, we could use a slightly more esoteric approach: using a language model to detect good splitting points.

This could be done in different ways, whose performance may vary depending on the dataset. A few similar ideas include:
- Simply ask an LLM for the split points
- Use an embedding model to capture the semantic meaning of each sentence, and add a split point where the topic seems to shift significantly
- Use a model trained on 'Next Sentence Prediction' and add a split point where the model confidently says sentences are disconnected.

We'll try the latter here with [BERT](https://huggingface.co/google-bert/bert-base-uncased). #TODO replace with MiniLM
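Before that, here is a sketch of the second idea (embedding similarity) for comparison. It boils down to a simple loop and is not part of the pipeline. It rests on several assumptions. The embeddings would come from any sentence-embedding model; toy 2-D vectors stand in for them here. The 0.5 threshold is arbitrary and would need tuning. `chunk_by_topic_shift` is a hypothetical helper.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def chunk_by_topic_shift(
    sentences: List[str],
    embeddings: List[Sequence[float]],
    threshold: float = 0.5,  # assumed value, to tune on real data
) -> List[str]:
    """Start a new chunk wherever two consecutive sentence embeddings are less similar than `threshold`."""
    if not sentences:
        return []
    chunks = []
    current = [sentences[0]]
    for i in range(1, len(sentences)):
        if cosine_similarity(embeddings[i - 1], embeddings[i]) < threshold:
            chunks.append(" ".join(current))  # topic shift: close the current chunk
            current = []
        current.append(sentences[i])
    chunks.append(" ".join(current))
    return chunks


# Toy 2-D "embeddings": the first two sentences point one way, the last two another
sentences = ["Cats purr.", "Cats nap a lot.", "Rust has no GC.", "Rust uses ownership."]
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
print(chunk_by_topic_shift(sentences, embeddings))
# ['Cats purr. Cats nap a lot.', 'Rust has no GC. Rust uses ownership.']
```

A real version would also cap the chunk size, the same way the BERT-based function below does.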

In [None]:
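from transformers import BertTokenizer, BertForNextSentencePrediction
import torch
import nltk

# nltk.sent_tokenize needs the Punkt sentence tokenizer models
# (recent nltk versions look for "punkt_tab", older ones for "punkt")
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)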

In [None]:
# Load the tokenizer and model once, outside of the chunking function,
# so they aren't reloaded for every example
NSP_TOKENIZER = BertTokenizer.from_pretrained("bert-base-uncased")
NSP_MODEL = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
NSP_MODEL.eval()
NSP_MODEL.to("cpu")  # switch to "cuda" if you have a GPU


def is_sentence_next(bert_model, bert_tokeniser, sentence_a: str, sentence_b: str) -> bool:
    encoding = bert_tokeniser(sentence_a, sentence_b, return_tensors="pt", truncation=True)
    encoding = encoding.to(bert_model.device)
    with torch.no_grad():  # inference only, no need to track gradients
        outputs = bert_model(**encoding)

    # Decision logic to decide if we'd like to break the chunk here.
    # Finding the right threshold/logic requires some trial and error and probably depends on the dataset used.
    # Logit 0 scores "sentence_b follows sentence_a", logit 1 scores "sentence_b is random"
    return bool(outputs.logits[0, 0] > outputs.logits[0, 1])


def chunk_text_with_bert(text: str, max_chunk_size_words: int = CHUNK_SIZE_WORDS) -> list[str]:
    # Split per sentence
    text_sentences = nltk.sent_tokenize(text)
    if not text_sentences:
        return []

    # Make sure sentences are all small enough to be sent forward. A dummy approach is to
    # truncate those.
    text_sentences = [
        " ".join(sentence.split()[:int(0.8 * max_chunk_size_words)])
        for sentence in text_sentences
    ]

    chunks = []
    last_sentence = text_sentences[0]
    current_chunk = text_sentences[0]

    # Start from the second sentence: the first one already opens `current_chunk`
    for sentence in text_sentences[1:]:
        current_chunk_size_words = len(current_chunk.split())
        sentence_size_words = len(sentence.split())
        if current_chunk_size_words + sentence_size_words > max_chunk_size_words:
            # The chunk would become too long: we have to split in any case
            chunks.append(current_chunk)
            current_chunk = sentence
        elif is_sentence_next(NSP_MODEL, NSP_TOKENIZER, last_sentence, sentence):
            # Add to the current chunk and continue
            current_chunk = current_chunk + " " + sentence
        else:
            # Split
            chunks.append(current_chunk)
            current_chunk = sentence
        last_sentence = sentence

    chunks.append(current_chunk)  # keep the last, still-open chunk
    return chunks
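Comparing the two logits in `is_sentence_next` amounts to asking whether the "is next" probability exceeds 0.5. One possible way to tune this decision is sketched below. A two-class softmax turns the two NSP logits into a probability, and sentences stay together only above a tunable `min_prob`. With `min_prob=0.5`, this reproduces the logit comparison up to ties. `is_continuation` and the example logit values are hypothetical.

```python
import math


def is_continuation(logit_is_next: float, logit_is_random: float, min_prob: float = 0.5) -> bool:
    """Two-class softmax over the NSP logits, then threshold P(sentence B follows sentence A)."""
    # softmax([a, b])[0] = exp(a) / (exp(a) + exp(b)) = 1 / (1 + exp(b - a))
    p_is_next = 1.0 / (1.0 + math.exp(logit_is_random - logit_is_next))
    return p_is_next >= min_prob


print(is_continuation(2.0, 1.0))                # True  (p ≈ 0.73)
print(is_continuation(2.0, 1.0, min_prob=0.9))  # False (stricter threshold -> more splits)
```

With the real model, the two arguments would be `outputs.logits[0, 0].item()` and `outputs.logits[0, 1].item()`. Raising `min_prob` produces smaller, more topically tight chunks.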


In [None]:
# This method is quite slow and would benefit from optimisation! For learning purposes only
wiki_data_chunked_w_model = wiki_data.select(range(200)).map(
    lambda example_batch: get_chunk_from_batch(example_batch, chunk_text_with_bert),
    batched=True,
    remove_columns=["id", "title", "text", "url"], # Removes columns because of row expansion
    batch_size=16
)

In [None]:
wiki_data_chunked_w_model[0:2]

In [None]:
# Save data back to disk
MODEL_DATASET_NAME = f"wiki-data-chunked-w-model-CS{CHUNK_SIZE_WORDS}"
wiki_data_chunked_w_model.save_to_disk(
    os.path.join(LOCAL_DATASET_FOLDER, MODEL_DATASET_NAME)
)

# Questions-Answers generation
It would be ideal to have human-curated question-answer pairs to evaluate our retrieval logic
(a bit like in https://www.kaggle.com/datasets/rtatman/questionanswer-dataset?resource=download).

If we can't afford this, we can always generate question-answer pairs with an LLM instead.

The notebook implements both methods below; the manual one is for when you have time on your hands to annotate things!
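To make the goal concrete: each pair stores the `chunk_id` it was written from, so a retriever can be scored by how often that chunk appears among its top-k results. Below is a minimal hit-rate@k sketch. `hit_rate_at_k` is a hypothetical helper, and the retrieved ids are placeholders for whatever your retriever returns.

```python
from typing import Sequence


def hit_rate_at_k(
    expected_chunk_ids: Sequence[int],
    retrieved_chunk_ids: Sequence[Sequence[int]],
    k: int = 5,
) -> float:
    """Fraction of questions whose source chunk appears in the retriever's top-k results."""
    if not expected_chunk_ids:
        raise ValueError("Need at least one question to evaluate")
    hits = sum(
        expected in list(retrieved)[:k]
        for expected, retrieved in zip(expected_chunk_ids, retrieved_chunk_ids)
    )
    return hits / len(expected_chunk_ids)


# Question 1 was written from chunk 3, question 2 from chunk 7
print(hit_rate_at_k([3, 7], [[1, 3, 5], [2, 4, 6]], k=3))  # 0.5
```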

In [None]:
dataset_to_add_questions_to = RECURSIVE_DATASET_NAME

In [None]:
wiki_data_chunked = load_from_disk(
    os.path.join(LOCAL_DATASET_FOLDER, dataset_to_add_questions_to)
)

In [None]:
wiki_data_chunked

## Manual hand-labeling

In [None]:
questions_dataset_name = f"{dataset_to_add_questions_to}-questions"

data = []
while True:  # Leave the question empty (or interrupt the cell) when you'd like to stop
    rnd_chunk_id = int(np.random.randint(len(wiki_data_chunked)))  # plain int for the Arrow-backed Dataset
    print("-----------------------")
    print("-----------------------")
    print("New chunk to annotate!")
    print("-----------------------")
    print("-----------------------")
    print(wiki_data_chunked[rnd_chunk_id]["text_chunk"])
    question = input("Type a question (leave empty to stop):")
    if not question.strip():
        break
    answer = input("Type the answer to the question:")

    new_el = {
        "chunk_id": rnd_chunk_id,
        "question": question,
        "answer": answer
    }
    data.append(new_el)

    print(f"New data point: {new_el}")

In [None]:
data[:5]

In [None]:
questions_dataset = Dataset.from_list(data)

In [None]:
questions_dataset.save_to_disk(
    os.path.join(LOCAL_DATASET_FOLDER, questions_dataset_name)
)

## LLM labeling

In [None]:
# For fun, we can try out the HF serverless Inference API
# os.environ["HF_TOKEN_SERVERLESS_API"] = "hf_*"
token = os.environ["HF_TOKEN_SERVERLESS_API"]  # ADD YOUR TOKEN TO YOUR ENV! (It's a free service)
client = InferenceClient(
    token=token,
)

In [None]:
def fetch_question_pair_from_llm(text: str):
    response = client.chat_completion(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        messages=[
            # The prompt can be improved: the LLM sometimes outputs questions like "Who are notable figures mentioned in this list?",
            # which obviously don't work, as whoever asks the question won't have access to the list... What would you suggest we change?
            {"role": "user", "content": "You are a helpful assistant. You will receive text chunks in quotes from users that originate from a wikipedia page. Your task will be to create a question/answer pair from this text chunk, with the answer being present in the chunk. Answer the query in the form [question] END_QUESTION [answer], nothing more. Please write short questions!"},
            {"role": "assistant", "content": "Sure! understood."},
            {"role": "user", "content": f"'{text}'"},
        ],
        max_tokens=50,
    )

    llm_output = response.choices[0].message.content

    # Split on the separator; strip() tolerates missing or extra spaces around it
    result = [part.strip() for part in llm_output.split("END_QUESTION")]
    if len(result) != 2:
        print("LLM Failed! returning None")
        return None, None

    question, answer = result
    return question, answer
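Generation most often fails at the parsing step, so it can help to test that step offline, without calling the API. The sketch below is a hypothetical, slightly more tolerant variant of the parsing above. It also strips literal square brackets, in case the model copies the `[question] END_QUESTION [answer]` template too literally, which is an assumed failure mode. It also rejects outputs where either side is empty.

```python
from typing import Optional, Tuple


def parse_question_answer(
    llm_output: str, separator: str = "END_QUESTION"
) -> Tuple[Optional[str], Optional[str]]:
    """Split '[question] END_QUESTION [answer]' into (question, answer), or (None, None) on failure."""
    parts = [part.strip().strip("[]").strip() for part in llm_output.split(separator)]
    if len(parts) != 2 or not all(parts):
        return None, None  # missing or repeated separator, or an empty side
    question, answer = parts
    return question, answer


print(parse_question_answer("[Where is Paris?] END_QUESTION [In France]"))
# ('Where is Paris?', 'In France')
print(parse_question_answer("Sorry, I cannot help with that."))
# (None, None)
```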

    


In [None]:
# Test it once before sending multiple requests
fetch_question_pair_from_llm(
    wiki_data_chunked[0]["text_chunk"]
)

In [None]:
N_REQUESTS = 100

data = []
for _ in tqdm(range(N_REQUESTS)):
    rnd_chunk_id = int(np.random.randint(len(wiki_data_chunked)))  # plain int for the Arrow-backed Dataset
    text_chunk = wiki_data_chunked[rnd_chunk_id]["text_chunk"]
    print("--- New ELEMENT ---")
    print("Fetching a question for:")
    print(text_chunk)
    print("LLM answered:")

    question, answer = fetch_question_pair_from_llm(text_chunk)
    if question is None:
        continue  # skip outputs that couldn't be parsed instead of storing None values

    new_el = {
        "chunk_id": rnd_chunk_id,
        "question": question,
        "answer": answer
    }
    data.append(new_el)

    print(f"New data point: {new_el}")

In [None]:
questions_dataset_llm = Dataset.from_list(data)

In [None]:
questions_dataset_llm_name = f"{dataset_to_add_questions_to}-questions-llm"

questions_dataset_llm.save_to_disk(
    os.path.join(LOCAL_DATASET_FOLDER, questions_dataset_llm_name)
)