# Fine-tuning Embeddings for RAG on Specific Data

#### Basic Overview of Fine-tuning Embeddings

In essence, what we want to do when we fine-tune our embedding models is very simple:

```
Move the embeddings for questions relating to a document
closer together with that document
```

We can think of fine-tuning our embedding models as follows:

1) We have some pair of text items that *should* be closer together
  - `Question`, `Document` pairs
  - EX: `Who drives the bus?`, `The bus was driven by Kyle, the Bus Driver`.

2) We use these pairs as labeled data to fine-tune our embedding model.

The process of training helps the model more accurately associate our questions with the correct documents.
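
To make "closer together" concrete, here is a minimal sketch using made-up 3-dimensional vectors (real embedding models use hundreds of dimensions). The `pull_towards` helper is hypothetical and only mimics the *effect* of training; actual fine-tuning updates the model's weights through a loss function rather than moving vectors directly.

```python
import math
from typing import List


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def pull_towards(query: List[float], target: List[float], step: float = 0.5) -> List[float]:
    """Hypothetical stand-in for training: move `query` part of the way to `target`."""
    return [q + step * (t - q) for q, t in zip(query, target)]


# Toy embeddings, invented for illustration only.
question_vec = [1.0, 0.0, 0.0]
document_vec = [0.0, 1.0, 0.0]

before = cosine_similarity(question_vec, document_vec)
after = cosine_similarity(pull_towards(question_vec, document_vec), document_vec)

print(f"similarity before: {before:.3f}, after: {after:.3f}")
```

The similarity between the question and its document increases after the pull, which is the behaviour we want fine-tuning to produce for every `Question`, `Document` pair.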

##### ❓ Question #1:

Describe the nuance between using Q&D pairs to train the embedding model vs. inter-document pairs/related sentences.

What caveats does this approach have? Are there any special considerations for what kind of Q's we should use?

---

**ANSWER:**

We are specifically relating *the questions* to *the documents*. This means that we are making our embedding model better at the very specific task of relating potential questions to specific documents.

There are many caveats, but the main ones are:

- Your Q's should reflect the Q's of your users
- This kind of fine-tuning will (purposefully) "overfit" on your data; this is the desired result in this case.

#### 🔍Answer #1:


The nuance between using Question & Document (Q&D) pairs vs. inter-document pairs/related sentences to train an embedding model lies in the specific task and goal of the fine-tuning:

##### Q&D pairs:

- This approach specifically trains the model to relate potential user questions to relevant documents.
- It optimizes the model for the task of retrieving documents based on natural language queries.
- The model learns to map questions and their corresponding relevant documents closer together in the embedding space.

##### Inter-document pairs/related sentences:
- This approach trains the model on relationships between different parts of the corpus itself.
- It helps the model understand the overall structure and connections within the document set.
- The model learns to represent similar or related content closer together, regardless of how it might be queried.

##### Key caveats and considerations:

1. Overfitting: Fine-tuning on Q&D pairs will intentionally "overfit" the model to the specific corpus and question types. This is actually desired for a targeted retrieval system, but may reduce generalizability.

2. Question quality: The questions used should reflect real user queries as closely as possible. Using artificially generated or overly simplistic questions may not translate well to real-world performance.

3. Coverage: Ensure the Q&D pairs cover a wide range of topics and query types within the corpus to avoid blind spots.

4. Bias: Be aware that the choice of questions can introduce biases in how the model interprets and retrieves information.

5. Evaluation: It's crucial to evaluate the fine-tuned model on a separate test set to ensure it generalizes well to unseen questions.

6. Maintenance: As the corpus or typical user questions evolve, the model may need to be periodically re-fine-tuned to maintain performance.

7. Complementary approaches: In some cases, a combination of Q&D fine-tuning and inter-document relationship training might provide the best overall performance.

8. Domain specificity: Q&D fine-tuning is particularly valuable for domain-specific applications where the vocabulary and concepts might be very different from general language.
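
For the evaluation caveat, one simple retrieval metric is the hit rate at `k`: the fraction of held-out questions for which at least one relevant document appears in the top `k` results. The function name and the toy IDs below are assumptions for illustration; the text above does not prescribe a specific metric.

```python
from typing import Dict, List


def hit_rate_at_k(
    retrieved: Dict[str, List[str]],
    relevant: Dict[str, List[str]],
    k: int = 3,
) -> float:
    """Fraction of questions with at least one relevant doc in the top-k results."""
    if not relevant:
        raise ValueError("relevant must not be empty")
    hits = 0
    for question_id, relevant_ids in relevant.items():
        # Questions with no retrieval results count as misses.
        top_k = retrieved.get(question_id, [])[:k]
        if any(doc_id in top_k for doc_id in relevant_ids):
            hits += 1
    return hits / len(relevant)


# Toy data: ranked retrieval results per held-out question (IDs are made up).
retrieved = {
    "q1": ["d3", "d1", "d7"],
    "q2": ["d4", "d5", "d6"],
}
relevant = {"q1": ["d1"], "q2": ["d2"]}

print(hit_rate_at_k(retrieved, relevant, k=3))
```

Computing this on questions the model never saw during fine-tuning, for both the base and the fine-tuned model, gives a like-for-like comparison.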

##### Summary

By focusing on Q&D pairs, the embedding model becomes highly specialized for
(and intentionally overfitted to) the task of retrieving relevant documents
from this corpus based on user queries, which is particularly useful for
building effective retrieval augmented generation (RAG) systems.

## Task 1: Dependencies and Boilerplate

We'll set up our `nest_asyncio` so we can leverage async loops in our Notebook.

We'll also install the required libraries we'll be using today, and set up our OpenAI API key!

In [1]:
import socket
# FIXME
# def is_remote_kernel() -> bool:
#     import ipykernel
#     connection_file = ipykernel.get_connection_file()
#     kernel_ip = connection_file.split('-')[1].split('.')[0]
#     local_ip = socket.gethostbyname(socket.gethostname())
#     return kernel_ip != local_ip
def is_remote_kernel() -> bool:
    local_ip = socket.gethostbyname(socket.gethostname())
    return local_ip != "127.0.1.1"

print(f"Is remote kernel: {is_remote_kernel()}")

Is remote kernel: True


In [2]:

%load_ext autoreload
%autoreload 2

In [3]:
import jupyter_black

jupyter_black.load(line_length=88, target_version="py39")

In [4]:
from loguru import logger

### Nest Asyncio

In [5]:
import nest_asyncio

nest_asyncio.apply()

### Provide OpenAI API Key

In [6]:
import os
from getpass import getpass

if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")

## Task 2: Loading Data

We'll be using a recent document released by the EU 'laying down harmonised rules on artificial intelligence and amending Regulations'.

The data can be found
[here](https://eur-lex.europa.eu/legal-content/EN/TXT/?uri=CELEX%3A32024R1689),
though we will be using an HTML version which was collected into the AIM
DataRepository.

<!-- TODO: ADD a summary of subtasks -->

### Task 2 : Cloning the Data Repository

We need to clone the source data repository.

In [7]:
import os
import tempfile
from git import Repo
from tqdm.notebook import tqdm_notebook
from typing import Optional,Tuple


class RepoManager:
    """
    Manages cloning of Git repositories with progress indication.
    """

    @staticmethod
    def clone_repo_with_progress(
        repo_url: str, clone_path: Optional[str] = None
    ) -> Tuple[Repo, str]:
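        if clone_path is None:
            # mkdtemp() leaves cleanup to the caller. A TemporaryDirectory object
            # would delete the folder once it is garbage-collected after return.
            clone_path = tempfile.mkdtemp()
        # Make sure the parent directory exists (abspath avoids an empty dirname).
        os.makedirs(os.path.dirname(os.path.abspath(clone_path)), exist_ok=True)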
        pbar = tqdm_notebook(unit="B", unit_scale=True, desc="Cloning")

        def progress(op_code, cur_count, max_count=None, message=""):
            pbar.total = max_count
            pbar.update(cur_count - pbar.n)

        repo: Repo = Repo.clone_from(
            url=repo_url, to_path=clone_path, progress=progress
        )
        pbar.close()
        return repo, clone_path


In [8]:
from pathlib import Path
import os

repo_root_path = Path(os.getcwd()).joinpath("DataRepository")

if not is_remote_kernel():
    temp_dir = tempfile.TemporaryDirectory()
    repo_root_path = Path(temp_dir.name).resolve()

In [9]:
repo_url = "https://github.com/AI-Maker-Space/DataRepository.git"
if not repo_root_path.joinpath(".git").exists():
    RepoManager.clone_repo_with_progress(
            repo_url=repo_url, clone_path=str(repo_root_path)
    )
else:
    logger.debug("Repository has already been cloned.")
logger.info(f"Repository cloned to: {str(repo_root_path)}")

2024-09-19 01:23:26.460 | DEBUG    | __main__:<module>:7 - Repository has already been cloned.
2024-09-19 01:23:26.461 | INFO     | __main__:<module>:8 - Repository cloned to: /notebooks/DataRepository


In [10]:
# %cd DataRepository

### Task 2 : Document Processor Class

We will load the HTML document, split it into chunks, and convert the chunks
into `langchain_core.documents.Document` objects.

Because this process can be repeated for other files, we wrap it in a class
rather than scattering the steps across cells.

In [11]:
import uuid
from pathlib import Path  # used in __init__; imported here so the cell is self-contained
from typing import List, Optional

from langchain_community.document_loaders import UnstructuredHTMLLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from loguru import logger


class DocumentProcessor:
    """
    Processes documents: loads, splits, and assigns unique IDs.
    """
    @logger.catch
    def __init__(
        self,
        file_path: str,
        chunk_size: int = 750,
        chunk_overlap: int = 20,
    ) -> None:
        self.file_path = Path(file_path).resolve()
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.documents: Optional[List[Document]] = None
        self._sanity_check()
        self._init_loader()
        self._init_splitter()
    @logger.catch
    def _sanity_check(self) -> None:
        if not isinstance(self.chunk_size, int) or self.chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if not isinstance(self.chunk_overlap, int) or self.chunk_overlap < 0:
            raise ValueError("chunk_overlap must be a non-negative integer")
        if not self.file_path.exists():
            raise FileNotFoundError(f"File not found: {str(self.file_path)}")
    @logger.catch
    def _init_loader(self) -> None:
        self.loader = UnstructuredHTMLLoader(str(self.file_path))
    @logger.catch
    def _init_splitter(self) -> None:
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
        )
    @logger.catch
    def load_documents(self) -> List[Document]:
        if self.documents is None:
            self.documents = self.loader.load()
        return self.documents
    @logger.catch
    def split_documents(self) -> List[Document]:
        if self.documents is None:
            self.load_documents()
        if not self.documents:
            raise ValueError("No documents loaded to split")
        split_docs: List[Document] = self.text_splitter.split_documents(self.documents)
        return split_docs
    @logger.catch
    def assign_unique_ids(self, documents: List[Document]) -> List[Document]:
        if not documents:
            raise ValueError("Input document list is empty")
        id_set = set()
        for document in documents:
            new_id = str(uuid.uuid4())
            while new_id in id_set:
                new_id = str(uuid.uuid4())
            id_set.add(new_id)
            document.metadata["id"] = new_id
        return documents
    @logger.catch
    def process(self) -> List[Document]:
        split_docs = self.split_documents()
        processed_docs = self.assign_unique_ids(split_docs)
        return processed_docs
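
To build intuition for `chunk_size` and `chunk_overlap`, here is a deliberately simplified sliding-window splitter. It is not how `RecursiveCharacterTextSplitter` works internally (that splitter tries a list of separators such as paragraph breaks, newlines, and spaces before falling back to characters); `sliding_window_chunks` is a hypothetical helper that only shows how the two parameters interact.

```python
from typing import List


def sliding_window_chunks(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Cut `text` into fixed-size windows that share `chunk_overlap` characters."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in the range [0, chunk_size)")

    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks: List[str] = []
    start = 0
    while start < len(text):
        chunks.append(text[start : start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the window already reached the end of the text
        start += step
    return chunks


chunks = sliding_window_chunks("abcdefghij", chunk_size=4, chunk_overlap=1)
print(chunks)  # ['abcd', 'defg', 'ghij']
```

Each chunk starts with the last character of the previous one, so a sentence cut at a boundary still appears partially in both neighbours.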

### Task 2 : Loading the EU AI Act Document

Next we're going to be using the `UnstructuredHTMLLoader` to load our HTML document into a LangChain document using the [Unstructured](https://api.python.langchain.com/en/latest/document_loaders/langchain_community.document_loaders.html.UnstructuredHTMLLoader.html) library.

In [12]:
html_path: Path = repo_root_path.joinpath("eu_ai_act.html")
processor = DocumentProcessor(str(html_path))

<!-- Next, we'll set up a classic naive chunking strategy as we only care that the documents get parsed into chunks that we can generate synthetic questions about. -->

In [13]:
# from langchain_text_splitters import RecursiveCharacterTextSplitter

# text_splitter = RecursiveCharacterTextSplitter(
#     chunk_size = 750,
#     chunk_overlap  = 20,
#     length_function = len
# )

In [14]:
# training_documents = text_splitter.split_documents(training_documents_loaded.load())

In [15]:
# import uuid

# id_set = set()

# for document in training_documents:
#   id = str(uuid.uuid4())
#   while id in id_set:
#     id = uuid.uuid4()
#   id_set.add(id)
#   document.metadata["id"] = id

Next, we can load the document, split it into chunks, and assign each chunk a unique ID as follows.

In [16]:
training_documents: List[Document] = processor.process()

Next, we'll split the chunks into training, validation, and test sets to
prepare our data for the next step. The following class shuffles the documents
once and then slices them, either by ratio or by fixed size.

In [17]:
import random
import math
from typing import List, Optional, Tuple
from langchain_core.documents import Document
from loguru import logger


class DocumentMixer:
    """
    Splits documents into training, validation, and test sets.
    """

    @logger.catch
    def __init__(
        self,
        documents: List[Document],
        train_ratio: Optional[float] = None,
        val_ratio: Optional[float] = None,
        test_ratio: Optional[float] = None,
        train_size: Optional[int] = None,
        val_size: Optional[int] = None,
        test_size: Optional[int] = None,
    ) -> None:
        if not documents:
            raise ValueError("The document list cannot be empty")
        self.documents = documents.copy()
        random.shuffle(self.documents)
        total_docs = len(documents)

        if train_size is not None and val_size is not None and test_size is not None:
            if train_size + val_size + test_size > total_docs:
                raise ValueError(
                    "Sum of train_size, val_size, and test_size exceeds total documents"
                )
            self.train_size = train_size
            self.val_size = val_size
            self.test_size = test_size
        elif (
            train_ratio is not None and val_ratio is not None and test_ratio is not None
        ):
            if not math.isclose(
                train_ratio + val_ratio + test_ratio, 1.0, rel_tol=1e-9
            ):
                raise ValueError("Train, validation, and test ratios must sum to 1")
            self.train_size = int(total_docs * train_ratio)
            self.val_size = int(total_docs * val_ratio)
            self.test_size = total_docs - self.train_size - self.val_size
        else:
            raise ValueError("Either sizes or ratios must be provided for splitting")

    @logger.catch
    def get_train_docs(self) -> List[Document]:
        return self.documents[: self.train_size]

    @logger.catch
    def get_val_docs(self) -> List[Document]:
        return self.documents[self.train_size : self.train_size + self.val_size]

    @logger.catch
    def get_test_docs(self) -> List[Document]:
        return self.documents[
            self.train_size
            + self.val_size : self.train_size
            + self.val_size
            + self.test_size
        ]
    @logger.catch
    def get_all_splits(self) -> Tuple[List[Document], List[Document], List[Document]]:
        train_docs: List[Document] = self.get_train_docs()
        val_docs: List[Document] = self.get_val_docs()
        test_docs: List[Document] = self.get_test_docs()
        return train_docs, val_docs, test_docs



Now, we will create the mixer and generate the splits.

In [18]:
# mixer = DocumentMixer(documents=training_documents,train_ratio = 0.7,
# val_ratio = 0.15, test_ratio=0.15)

# mixer = DocumentMixer(
#     documents=training_documents,
#     train_size=300,
#     val_size=50,
#     test_size=50,
# )
mixer = DocumentMixer(
    documents=training_documents, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15
)

In [19]:
# from langchain_core.documents.base import Document

training_split_documents  = mixer.get_train_docs()
val_split_documents = mixer.get_val_docs()
test_split_documents = mixer.get_test_docs()

# training_split_documents: list[Document] = training_documents[:300]
# val_split_documents = training_documents[300:350]
# test_split_documents = training_documents[350:400]
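
The ratio branch of `DocumentMixer` truncates the train and validation sizes with `int()` and gives the remainder to the test set, so the three sizes always add up to the total. The sketch below reproduces only that arithmetic; the document count of 103 is made up. It also shows one way to make a shuffle repeatable with a seeded `random.Random`, which the class above does not do (`split_sizes` and `seeded_shuffle` are hypothetical helpers).

```python
import random
from typing import List, Tuple


def split_sizes(total_docs: int, train_ratio: float, val_ratio: float) -> Tuple[int, int, int]:
    """Mirror the ratio arithmetic: truncate train/val, give the remainder to test."""
    train_size = int(total_docs * train_ratio)
    val_size = int(total_docs * val_ratio)
    test_size = total_docs - train_size - val_size
    return train_size, val_size, test_size


def seeded_shuffle(items: List[int], seed: int) -> List[int]:
    """Return a shuffled copy; the same seed yields the same order on every call."""
    shuffled = items.copy()
    random.Random(seed).shuffle(shuffled)
    return shuffled


sizes = split_sizes(total_docs=103, train_ratio=0.7, val_ratio=0.15)
print(sizes)  # (72, 15, 16)

first = seeded_shuffle(list(range(10)), seed=42)
second = seeded_shuffle(list(range(10)), seed=42)
print(first == second)  # True
```

Without a fixed seed, re-running the notebook assigns different chunks to each split, which makes evaluation numbers harder to compare between runs.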

## Task 3: Constructing a Fine-tuning Dataset

Using the documents we created above, we can finally start constructing a fine-tuning dataset utilizing OpenAI's `gpt-4o-mini` (released [July 18th, 2024](https://openai.com/index/gpt-4o-mini-advancing-cost-efficient-intelligence/)).

The basic idea here is straightforward enough:

1. We look at a document
2. We generate questions that could be answered by that document

This gives us a number of question/context pairs that we can use to fine-tune our Embeddings model.

In [20]:
# from langchain_openai import ChatOpenAI

# qa_chat_model = ChatOpenAI(
#     model="gpt-4o-mini",
#     temperature=0
# )

We'll create a simple Question Generation prompt to query `gpt-4o-mini` to generate Questions for each retrieved context.

In [21]:
# from langchain_core.prompts import ChatPromptTemplate

# qa_prompt = """\
# Given the following context, you must generate questions based on only the provided context.

# You are to generate {n_questions} questions which should be provided in the following format:

# 1. QUESTION #1
# 2. QUESTION #2
# ...

# Context:
# {context}
# """

# qa_prompt_template = ChatPromptTemplate.from_template(qa_prompt)

We'll create a simple chain to query the LLM!

In [22]:
# question_generation_chain = qa_prompt_template | qa_chat_model

There's a lot going on in the question-generation function, so let's take a deeper look:

1. First, we provide a list of documents and a number of questions.
2. For each document in our list, we generate `n_questions` questions.
3. We then associate those questions and contexts via a `UUID`.

> NOTE: The reason we're doing this `UUID` association is for ease of use later in the notebook.
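
The three steps above can be sketched without calling an LLM. In the sketch below, `fake_generate_questions` is a hypothetical stand-in for the model call, and plain dictionaries stand in for LangChain `Document` objects; only the `UUID` bookkeeping is the point.

```python
import uuid
from typing import Callable, Dict, List, Tuple


def fake_generate_questions(context: str, n_questions: int) -> List[str]:
    """Hypothetical stand-in for the LLM call; returns placeholder questions."""
    return [f"Placeholder question {i + 1} about: {context[:20]}" for i in range(n_questions)]


def link_questions_to_docs(
    docs: List[Dict[str, str]],
    n_questions: int,
    generate: Callable[[str, int], List[str]] = fake_generate_questions,
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
    """Give every generated question a UUID and record which document it came from."""
    questions: Dict[str, str] = {}
    relevant_docs: Dict[str, List[str]] = {}
    for doc in docs:
        for question in generate(doc["text"], n_questions):
            question_id = str(uuid.uuid4())
            questions[question_id] = question
            relevant_docs[question_id] = [doc["id"]]
    return questions, relevant_docs


# Toy documents; the IDs and texts are invented for illustration.
docs = [
    {"id": "doc-a", "text": "First toy chunk of text."},
    {"id": "doc-b", "text": "Second toy chunk of text."},
]
questions, relevant_docs = link_questions_to_docs(docs, n_questions=2)
print(len(questions), len(relevant_docs))
```

Both dictionaries share the same question IDs as keys, so a question's text and its source document can be looked up independently later on.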

##### 🏗️ Activity #1:

We have:

- Lists of `Documents` with the `metadata` field `id`.

We need:

- An object whose keys are question `id`s and whose values are the `str` questions.
- An object whose keys are `question_id`s and whose values are `List[str]`, a list of the associated `context_id`s.

An Example:

question_object:
```python
{
'b4b95fb6-f827-4454-aa5b-20e62733f172': 'What types of accessible formats are available for persons with disabilities?',
'df58ee4f-714c-419e-8324-94e5870574e2': 'How do accessible formats benefit persons with disabilities?',
'505fce8b-0e56-48de-a251-61027e396918': 'What are some of the risks associated with the increasing capabilities of AI systems that generate synthetic content?',
'8ff0ab33-60dc-4fee-8958-91bfb686aca8': 'Why is it important for providers of AI systems to embed technical solutions for marking and detecting synthetic content?'
}
 ```

 context_object:
 ```python
{
'b4b95fb6-f827-4454-aa5b-20e62733f172': ['dd75bf94-75f3-4603-8e4b-5522f6925638'],
'df58ee4f-714c-419e-8324-94e5870574e2': ['dd75bf94-75f3-4603-8e4b-5522f6925638'],
'505fce8b-0e56-48de-a251-61027e396918': ['ffe3893f-688c-48e8-90bd-7a9feb953d90'],
'8ff0ab33-60dc-4fee-8958-91bfb686aca8': ['ffe3893f-688c-48e8-90bd-7a9feb953d90'],
}
 ```

 As you can see, a piece of context can be associated with more than 1 question.

 The task is to write the Python function(s) to accomplish this task.

 Your function signature is provided below, along with the desired return values.

 > NOTE: You can make any modifications that you desire - assuming that you have the correct input and outputs.

In [23]:
# import tqdm
# import uuid


# def create_questions(documents, n_questions):# -> tuple[dict[Any, Any], dict[Any, Any]]:
#     questions = {}
#     relevant_docs = {}

#     for doc in tqdm.tqdm(documents, desc="Processing documents"):
#         context = doc.page_content
#         doc_id = doc.metadata["id"]

#         # Generate questions using the question_generation_chain
#         generated_questions = question_generation_chain.invoke(
#             {"context": context, "n_questions": n_questions}
#         )

#         # Check if generated_questions is a string or a list
#         if isinstance(generated_questions.content, str):
#             question_list = generated_questions.content.strip().split("\n")
#         elif isinstance(generated_questions.content, list):
#             question_list = generated_questions.content
#         else:
#             raise ValueError(
#                 f"Unexpected type for generated_questions: {type(generated_questions.content)}"
#             )

#         for q in question_list:
#             # If q is a dict, assume it contains the question
#             if isinstance(q, dict):
#                 q = q.get("question", "")

#             # Remove numbering and any leading/trailing whitespace
#             q = q.split(".", 1)[-1].strip() if isinstance(q, str) else str(q)

#             # Generate a unique ID for the question
#             question_id = str(uuid.uuid4())

#             # Add the question to the questions dictionary
#             questions[question_id] = q

#             # Add the document ID to the relevant_docs dictionary
#             relevant_docs[question_id] = [doc_id]

#     return questions, relevant_docs
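
The parsing step in the function above relies on `split(".", 1)`, which also cuts text out of a line that has no leading number but does contain a period. A slightly more defensive alternative is to strip only a leading `1.` or `1)` style prefix with a regular expression. `parse_numbered_questions` below is a hypothetical helper, not part of the original notebook, and the sample model output is invented.

```python
import re
from typing import List

# Matches a leading list marker such as "1." or "12)" plus surrounding whitespace.
_NUMBER_PREFIX = re.compile(r"^\s*\d+\s*[.)]\s*")


def parse_numbered_questions(raw: str) -> List[str]:
    """Split model output into questions, dropping list numbering and blank lines."""
    questions: List[str] = []
    for line in raw.splitlines():
        cleaned = _NUMBER_PREFIX.sub("", line).strip()
        if cleaned:
            questions.append(cleaned)
    return questions


raw_output = (
    "1. What does Art. 5 prohibit?\n"
    "\n"
    "2) Who enforces the Act?\n"
    "What is covered by Art. 6?"
)
parsed = parse_numbered_questions(raw_output)
print(parsed)

# For comparison, the split-based approach mangles the unnumbered line:
print("What is covered by Art. 6?".split(".", 1)[-1].strip())  # '6?'
```

Whether this matters in practice depends on how consistently the model follows the numbered-list format requested in the prompt.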

We aggregate the question-generation logic in the following class, rather than
using module-level functions, to follow object-oriented programming principles.

In [24]:
# import uuid
# from typing import List, Dict, Optional, Tuple, Union
# from langchain_core.documents import Document
# from langchain_core.prompts import ChatPromptTemplate
# from langchain_openai import ChatOpenAI
# from tqdm.notebook import tqdm as tqdm_notebook
# import concurrent.futures
# from datasets import Dataset
# from loguru import logger
# import json
# from langchain_core.messages import BaseMessage


# class DatasetGenerator:
#     def __init__(self, documents: List[Document], model_name: str = "gpt-4o-mini"):
#         if not documents or not all(isinstance(doc, Document) for doc in documents):
#             raise ValueError("documents must be a non-empty list of Document objects")

#         self.documents = documents
#         self.qa_chat_model = ChatOpenAI(model=model_name, temperature=0)
#         self.qa_prompt_template = ChatPromptTemplate.from_template(
#             """
#             Given the following context, generate {n_questions} questions based only on the provided context.
#             Format the questions as a numbered list:

#             1. QUESTION #1
#             2. QUESTION #2
#             ...

#             Context:
#             {context}
#             """
#         )
#         self.questions: Dict[str, str] = {}
#         self.relevant_contexts: Dict[str, List[str]] = {}
#         self.corpus: Dict[str, str] = {}
#         self._dataset_cache: Optional[Dataset] = None
#         self._processed: bool = False
#         logger.info(
#             f"DatasetGenerator initialized with {len(documents)} documents and model {model_name}"
#         )

#     def process(
#         self,
#         n_questions: int = 2,
#         max_workers: Optional[int] = None,
#     ) -> None:
#         if self._processed:
#             logger.info("Dataset already processed. Returning without reprocessing.")
#             return

#         logger.info(f"Processing dataset with {n_questions} questions per document")

#         self._generate_corpus()
#         self._generate_questions(n_questions, max_workers)
#         self._generate_relevant_contexts()

#         self._processed = True
#         logger.info(
#             f"Processed dataset with {len(self.questions)} questions and {len(self.corpus)} documents"
#         )

#     def _generate_corpus(self) -> None:
#         logger.info("Generating corpus")
#         self.corpus = {doc.metadata["id"]: doc.page_content for doc in self.documents}
#         logger.debug(f"Generated corpus with {len(self.corpus)} documents")

#     def _generate_questions(
#         self, n_questions: int, max_workers: Optional[int],
#     ) -> None:
#         logger.info(f"Generating questions with {n_questions} questions per document")
#         self.questions.clear()

#         if max_workers is not None:
#             logger.info("Running LLM queries in parallel")
#             with concurrent.futures.ThreadPoolExecutor(
#                 max_workers=max_workers
#             ) as executor:
#                 list(
#                     tqdm_notebook(
#                         executor.map(
#                             lambda doc: self._process_single_document(doc, n_questions),
#                             self.documents,
#                         ),
#                         total=len(self.documents),
#                         desc="Processing documents",
#                     )
#                 )
#         else:
#             logger.info("Running LLM queries sequentially")
#             for doc in tqdm_notebook(self.documents, desc="Processing documents"):
#                 self._process_single_document(doc, n_questions)

#     def _process_single_document(self, doc: Document, n_questions: int) -> None:
#         doc_questions, doc_id = self._process_document(doc, n_questions)
#         for question in doc_questions:
#             question_id = str(uuid.uuid4())
#             self.questions[question_id] = question
#             logger.trace(f"Generated question: {question} for document {doc_id}")

#     def _generate_relevant_contexts(self) -> None:
#         logger.info("Generating relevant contexts")
#         self.relevant_contexts = {
#             q_id: [doc.metadata["id"]]
#             for doc in self.documents
#             for q_id in self.questions
#         }
#         logger.debug(
#             f"Generated relevant contexts for {len(self.relevant_contexts)} questions"
#         )

#     def _process_document(
#         self, doc: Document, n_questions: int
#     ) -> Tuple[List[str], str]:
#         context = doc.page_content
#         doc_id = doc.metadata["id"]
#         generated_questions: BaseMessage = self.qa_chat_model.invoke(
#             self.qa_prompt_template.format(context=context, n_questions=n_questions)
#         )
#         processed_questions = self._process_model_output(
#             generated_questions.content, n_questions
#         )
#         logger.trace(
#             f"Generated {len(processed_questions)} questions for document {doc_id}"
#         )
#         return processed_questions, doc_id

#     @staticmethod
#     def _process_model_output(
#         content: Union[str, List[Union[str, Dict]]], n_questions: int
#     ) -> List[str]:
#         processed_questions = []
#         if isinstance(content, str):
#             questions = content.strip().split("\n")
#         elif isinstance(content, list):
#             questions = content
#         else:
#             raise ValueError(f"Unexpected content type: {type(content)}")

#         for q in questions:
#             if isinstance(q, dict):
#                 q = q.get("question", "")
#             if isinstance(q, str):
#                 q = q.split(".", 1)[-1].strip()
#                 if q:
#                     processed_questions.append(q)

#         if len(processed_questions) != n_questions:
#             logger.warning(
#                 f"Expected {n_questions} questions, but got {len(processed_questions)}"
#             )

#         return processed_questions

#     def get_dataset(self) -> Dataset:
#         if not self._processed:
#             raise ValueError("Dataset has not been processed. Call process() first.")
        
#         logger.info("Creating Dataset object")
#         try:
#             if self._dataset_cache is None:
#                 questions = list(self.questions.values())
#                 relevant_contexts = list(self.relevant_contexts.values())
#                 corpus_texts = list(self.corpus.values())
#                 corpus_ids = list(self.corpus.keys())

#                 # Ensure all lists have the same length
#                 n_questions = len(questions)
#                 n_docs = len(corpus_texts)

#                 if n_questions > n_docs:
#                     logger.warning(f"More questions ({n_questions}) than documents ({n_docs}). Truncating questions.")
#                     questions = questions[:n_docs]
#                     relevant_contexts = relevant_contexts[:n_docs]
#                 elif n_questions < n_docs:
#                     logger.warning(f"Fewer questions ({n_questions}) than documents ({n_docs}). Padding with empty strings.")
#                     questions.extend([""] * (n_docs - n_questions))
#                     relevant_contexts.extend([[] for _ in range(n_docs - n_questions)])
#                 self._dataset_cache = Dataset.from_dict({
#                     "questions": questions,
#                     "relevant_contexts": relevant_contexts,
#                     "corpus": corpus_texts,
#                     "doc_id": corpus_ids
#                 })
            
#             if self._dataset_cache is not None:
#                 logger.debug(f"Returning Dataset with {len(self._dataset_cache)} entries")
#             else:
#                 logger.warning("Dataset cache is None, returning empty Dataset")
#                 return Dataset.from_dict({"questions": [], "relevant_contexts": [], "corpus": [], "doc_id": []})
            
#             return self._dataset_cache
#         except Exception as e:
#             logger.error(f"Failed to create Dataset: {str(e)}")
#             raise
#     def get_corpus(self) -> Dict[str, str]:
#         if not self._processed:
#             raise ValueError("Dataset has not been processed. Call process() first.")
#         return self.corpus

#     def get_questions(self) -> Dict[str, str]:
#         if not self._processed:
#             raise ValueError("Dataset has not been processed. Call process() first.")
#         return self.questions

#     def get_relevant_contexts(self) -> Dict[str, List[str]]:
#         if not self._processed:
#             raise ValueError("Dataset has not been processed. Call process() first.")
#         return self.relevant_contexts

#     def clear(self) -> None:
#         logger.info("Clearing all processed data")
#         self.questions.clear()
#         self.relevant_contexts.clear()
#         self.corpus.clear()
#         self._dataset_cache = None
#         self._processed = False

#     def save_dataset_to_json(self, file_path: str) -> None:
#         if not self._processed:
#             raise ValueError("Dataset has not been processed. Call process() first.")

#         logger.info(f"Saving dataset to {file_path}")
#         try:
#             data_dict = {
#                 "questions": self.questions,
#                 "relevant_contexts": self.relevant_contexts,
#                 "corpus": self.corpus,
#             }
#             with open(file_path, "w") as f:
#                 json.dump(data_dict, f, indent=4, default=lambda x: x.__dict__, ensure_ascii=False)
#             logger.info(f"Dataset successfully saved to {file_path}")
#         except Exception as e:
#             logger.error(f"Failed to save dataset to {file_path}: {str(e)}")
#             raise

#     def load_dataset_from_json(self, file_path: str) -> None:
#         logger.info(f"Loading dataset from {file_path}")
#         try:
#             with open(file_path, "r") as f:
#                 data_dict = json.load(f)

#             self.questions = data_dict["questions"]
#             self.relevant_contexts = data_dict["relevant_contexts"]
#             self.corpus = data_dict["corpus"]

#             self._dataset_cache = None  # Clear cache to force regeneration
#             self._processed = True
#             logger.info(f"Dataset successfully loaded from {file_path}")
#         except FileNotFoundError:
#             logger.error(f"File not found: {file_path}")
#             raise
#         except json.JSONDecodeError:
#             logger.error(f"Invalid JSON format in file: {file_path}")
#             raise
#         except KeyError as e:
#             logger.error(f"Missing key in loaded data: {str(e)}")
#             raise
#         except Exception as e:
#             logger.error(f"Unexpected error while loading dataset: {str(e)}")
#             raise

In [25]:
import json
from typing import Any, Dict, Generator, List, TypeAlias

from loguru import logger
from sentence_transformers import InputExample
from torch.utils.data import Dataset
from torch.utils.data import Dataset as TorchDataset

# Type aliases
QuestionID: TypeAlias = str
DocID: TypeAlias = str
QuestionText: TypeAlias = str
DocText: TypeAlias = str
QuestionsDict: TypeAlias = Dict[QuestionID, QuestionText]
RelevantContextsDict: TypeAlias = Dict[QuestionID, List[DocID]]
CorpusDict: TypeAlias = Dict[DocID, DocText]


class QADataset(TorchDataset):
    """
    Data class for question-answering dataset.
    Encapsulates questions, relevant contexts, and corpus.
    """

    @logger.catch(reraise=True)
    def __init__(
        self,
        questions: QuestionsDict,
        relevant_contexts: RelevantContextsDict,
        corpus: CorpusDict,
    ):
        self.questions: QuestionsDict = questions
        self.relevant_contexts: RelevantContextsDict = relevant_contexts
        self.corpus: CorpusDict = corpus
        self.validate()
        self._question_ids: List[QuestionID] = list(self.questions.keys())


    @logger.catch(reraise=True)
    def serialize(self, file_path: str) -> None:
        """
        Serialize the dataset to a JSON file with pretty printing.
        """
        dataset = {
            "questions": self.questions,
            "relevant_contexts": self.relevant_contexts,
            "corpus": self.corpus,
        }
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(dataset, f, indent=4, ensure_ascii=False)
        logger.info(f"Dataset serialized to {file_path}")

    @classmethod
    @logger.catch(reraise=True)
    def deserialize(cls, file_path: str) -> 'QADataset':
        """
        Deserialize the dataset from a JSON file.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            dataset = json.load(f)
        questions = dataset.get("questions", {})
        relevant_contexts = dataset.get("relevant_contexts", {})
        corpus = dataset.get("corpus", {})
        logger.info(f"Dataset deserialized from {file_path}")
        instance = cls(questions, relevant_contexts, corpus)
        instance.validate()
        return instance

    @logger.catch(reraise=True)
    def get_questions(self) -> Dict[str, str]:
        return self.questions

    @logger.catch(reraise=True)
    def get_relevant_contexts(self) -> Dict[str, List[str]]:
        return self.relevant_contexts

    @logger.catch(reraise=True)
    def get_corpus(self) -> Dict[str, str]:
        return self.corpus

    @logger.catch(reraise=True)
    def validate(self) -> None:
        """
        Validate the integrity of the dataset.
        """
        if not self.questions or not self.relevant_contexts or not self.corpus:
            raise ValueError("Dataset components cannot be empty.")
        if set(self.questions.keys()) != set(self.relevant_contexts.keys()):
            raise ValueError("Mismatch between questions and relevant contexts.")
        for doc_ids in self.relevant_contexts.values():
            for doc_id in doc_ids:
                if doc_id not in self.corpus:
                    raise ValueError(f"Document ID {doc_id} in relevant contexts not found in corpus.")
                
    @logger.catch(reraise=True)
    def __len__(self) -> int:
        return len(self._question_ids)

    @logger.catch(reraise=True)
    def __getitem__(self, idx: int) -> InputExample:
        """
        Returns an InputExample containing the question and its corresponding context.
        """
        question_id = self._question_ids[idx]
        question_text = self.questions[question_id]
        doc_ids:Optional[List[DocID]] = self.relevant_contexts.get(question_id)
        if not doc_ids:
            raise ValueError(f"No relevant contexts found for question ID {question_id}")
        doc_id = doc_ids[0]
        context:Optional[DocText] = self.corpus.get(doc_id)
        if context is None:
            raise ValueError(f"Document ID {doc_id} not found in corpus")
        example = InputExample(texts=[question_text, context])
        return example
    def __iter__(self) -> Generator[InputExample, Any, None]:
        for idx in range(len(self)):
            yield self[idx]


In [26]:
import uuid
from typing import List, Dict, Union, Any, Optional
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from tqdm.notebook import tqdm_notebook
import concurrent.futures
from loguru import logger
class DatasetGenerator:
    """
    Generates datasets for question-answering tasks from documents.
    """

    def __init__(self, documents: List[Document], model_name: str = "gpt-4o-mini") -> None:
        if not documents or not all(isinstance(doc, Document) for doc in documents):
            raise ValueError("documents must be a non-empty list of Document objects")
        self.documents: List[Document] = documents
        self.qa_chat_model = ChatOpenAI(model=model_name, temperature=0)
        self.qa_prompt_template: ChatPromptTemplate = ChatPromptTemplate.from_template(
            """
            Given the following context, generate {n_questions} questions based only on the provided context.
            Format the questions as a numbered list:
            1. QUESTION #1
            2. QUESTION #2
            ...
            Context:
            {context}
            """
        )
        self.questions: Dict[str, str] = {}
        self.relevant_contexts: Dict[str, List[str]] = {}
        self.corpus: Dict[str, str] = {}
        self._question_doc_map: Dict[str, str] = {}

    @logger.catch(reraise=True)
    def generate_dataset(
        self, n_questions: int = 2, max_workers: Optional[int] = None
    ) -> QADataset:
        """
        Public method to generate the complete dataset.
        """
        self._generate_questions(n_questions, max_workers)
        self._generate_relevant_contexts()
        self._generate_corpus()
        dataset = QADataset(self.questions, self.relevant_contexts, self.corpus)
        dataset.validate()
        logger.info("Dataset generation complete.")
        return dataset

    @logger.catch(reraise=True)
    def _generate_questions(
        self, n_questions: int = 2, max_workers: Optional[int] = None
    ) -> None:
        """
        Internal method to generate questions from documents.
        """
        if not self.documents:
            raise ValueError("No documents provided for question generation.")
        self.questions.clear()
        self._question_doc_map.clear()

        if max_workers is None or max_workers <= 1:
            for doc in tqdm_notebook(self.documents, desc="Generating questions"):
                self._process_document(doc, n_questions)
        else:
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = {
                    executor.submit(self._process_document, doc, n_questions): doc.metadata["id"]
                    for doc in self.documents
                }
                for future in tqdm_notebook(concurrent.futures.as_completed(futures), total=len(futures), desc="Generating questions"):
                    doc_id = futures[future]
                    try:
                        future.result()
                    except Exception as e:
                        logger.exception(f"Error processing document ID {doc_id}: {e}")
                        raise

    @logger.catch(reraise=True)
    def _process_document(self, doc: Document, n_questions: int) -> None:
        """
        Internal method to generate questions from a single document.
        """
        context = doc.page_content
        doc_id = doc.metadata["id"]
        prompt = self.qa_prompt_template.format(context=context, n_questions=n_questions)
        generated_questions = self.qa_chat_model.invoke(prompt)
        processed_questions = self._process_model_output(
            generated_questions.content, n_questions
        )
        for q in processed_questions:
            question_id = str(uuid.uuid4())
            self.questions[question_id] = q
            self._question_doc_map[question_id] = doc_id

    @staticmethod
    @logger.catch(reraise=True)
    def _process_model_output(
        content: Union[str, List[Union[str, Dict[str, str]]]], n_questions: int
    ) -> List[str]:
        """
        Internal method to process the output from the language model into a list of questions.
        """
        if isinstance(content, str):
            processed_questions: List[str] = DatasetGenerator._parse_questions_from_string(content)
        elif isinstance(content, list):
            processed_questions = DatasetGenerator._parse_questions_from_list(content)
        else:
            raise ValueError(
                f"Unexpected type for generated_questions content: {type(content)}"
            )
        if len(processed_questions) != n_questions:
            raise ValueError(
                f"Expected {n_questions} questions, but got {len(processed_questions)}"
            )
        return processed_questions

    @staticmethod
    @logger.catch(reraise=True)
    def _parse_questions_from_string(content: str) -> List[str]:
        """
        Parse questions from a string output.
        """
        lines: List[str] = content.strip().split("\n")
        questions:List[str] = []
        for line in lines:
            question: Optional[str] = DatasetGenerator._extract_question_from_line(line)
            if question:
                questions.append(question)
        return questions

    @staticmethod
    @logger.catch(reraise=True)
    def _parse_questions_from_list(content_list: List[Union[str, Dict[str, str]]]) -> List[str]:
        """
        Parse questions from a list output.
        """
        questions:List[str] = []
        for item in content_list:
            question:Optional[str]  = DatasetGenerator._extract_question_from_item(item)
            if question:
                questions.append(question)
        return questions

    @staticmethod
    @logger.catch(reraise=True)
    def _extract_question_from_line(line: str) -> Optional[str]:
        """
        Extract question text from a line of text.
        """
        if line.strip():
            question: str = line.split(".", 1)[-1].strip()
            return question
        return None

    @staticmethod
    @logger.catch(reraise=True)
    def _extract_question_from_item(item: Union[str, Dict[str, str]]) -> Optional[str]:
        """
        Extract question text from an item in a list.
        """
        if isinstance(item, str):
            return DatasetGenerator._extract_question_from_line(item)
        elif isinstance(item, dict):
            question = item.get("question", "").strip()
            if question:
                return question
            else:
                raise ValueError("Dictionary item missing 'question' key.")
        else:
            raise ValueError(f"Unexpected item type in content list: {type(item)}")

    @logger.catch(reraise=True)
    def _generate_relevant_contexts(self) -> None:
        """
        Internal method to generate relevant contexts mapping.
        """
        if not self._question_doc_map:
            raise ValueError("No questions have been generated to map relevant contexts.")
        self.relevant_contexts.clear()
        for question_id, doc_id in self._question_doc_map.items():
            self.relevant_contexts[question_id] = [doc_id]

    @logger.catch(reraise=True)
    def _generate_corpus(self) -> None:
        """
        Internal method to generate the corpus.
        """
        if not self.documents:
            raise ValueError("No documents available to generate corpus.")
        self.corpus = {doc.metadata["id"]: doc.page_content for doc in self.documents}
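
To make the parsing step concrete, here is a minimal sketch of turning a numbered-list reply into a list of questions. It is a stricter, regex-based variant of the `split(".", 1)` approach used in `_extract_question_from_line`, not the class's actual implementation; the name `parse_numbered_questions` and the sample reply are hypothetical.

```python
import re
from typing import List

# Matches lines such as "1. What is AI?" or "2) Who drives the bus?".
NUMBERED_LINE = re.compile(r"^\s*\d+[.)]\s+(.*\S)\s*$")


def parse_numbered_questions(reply: str) -> List[str]:
    """Return the question text from each numbered line, skipping other lines."""
    questions: List[str] = []
    for line in reply.splitlines():
        match = NUMBERED_LINE.match(line)
        if match:
            questions.append(match.group(1))
    return questions


# Hypothetical model reply with a preamble line before the list.
sample_reply = """Here are the questions:
1. Who drives the bus?
2. What is Kyle's job?
"""
print(parse_numbered_questions(sample_reply))
```

Because lines without a leading number are ignored, a chatty preamble is less likely to trip the `n_questions` count check than it would be with a plain split on the first period.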

We'll use the `DatasetGenerator` class to generate training, validation, and test data, with two questions per document (`n_questions=2`) for each split.

In [27]:
import os
max_workers: Optional[int] = os.cpu_count()

In [28]:
# training_questions, training_relevant_contexts = ### YOUR CODE HERE
# training_questions, training_relevant_contexts = create_questions(training_split_documents, n_questions=2)
from typing import Any, Dict, List


train_generator = DatasetGenerator(training_split_documents)
val_generator = DatasetGenerator(val_split_documents)
test_generator = DatasetGenerator(test_split_documents)

In [29]:


train_dataset: QADataset = train_generator.generate_dataset(n_questions=2,max_workers=max_workers)


Generating questions:   0%|          | 0/720 [00:00<?, ?it/s]

2024-09-19 01:25:07.609 | INFO     | __main__:generate_dataset:47 - Dataset generation complete.


In [30]:
train_questions: Dict[str, str] = train_dataset.get_questions()
logger.info(f"Number of training questions: {len(train_questions)}")
train_corpus: Dict[str, str] = train_dataset.get_corpus()
logger.info(f"Number of documents in training corpus: {len(train_corpus)}")
train_relevant_contexts = train_dataset.get_relevant_contexts()

# Get a sample question and its ID
train_sample_question_id, train_sample_question = next(iter(train_questions.items()))
logger.info(f"Sample question: {train_sample_question_id}")

# Get the relevant document ID for this question
train_sample_relevant_doc_id = train_relevant_contexts[train_sample_question_id][0]
logger.info(f"Relevant document ID: {train_sample_relevant_doc_id}")

# Optionally, you can also display the content of the relevant document
train_sample_doc_content = train_corpus[train_sample_relevant_doc_id]
logger.info(f"Sample document content: {train_sample_doc_content[:100]}...")  # Display first 100 characters

2024-09-19 01:25:07.820 | INFO     | __main__:<module>:2 - Number of training questions: 1440
2024-09-19 01:25:07.821 | INFO     | __main__:<module>:4 - Number of documents in training corpus: 720
2024-09-19 01:25:07.822 | INFO     | __main__:<module>:9 - Sample question: e4437b49-6f1e-4ae8-b2bc-f873536fca85
2024-09-19 01:25:07.824 | INFO     | __main__:<module>:13 - Relevant document ID: e7dff837-c329-40a6-be60-d9565405e9ad
2024-09-19 01:25:07.825 | INFO     | __main__:<module>:17 - Sample document content: evidence and the latest developments in technology;...


In [31]:
# val_questions, val_relevant_contexts = ### YOUR CODE HERE
# val_questions, val_relevant_contexts = create_questions(val_split_documents, n_questions=2)


val_dataset: QADataset = val_generator.generate_dataset(n_questions=2,max_workers=max_workers)


Generating questions:   0%|          | 0/154 [00:00<?, ?it/s]

2024-09-19 01:25:26.941 | INFO     | __main__:generate_dataset:47 - Dataset generation complete.


In [32]:
val_questions: Dict[str, str] = val_dataset.get_questions()
logger.info(f"Number of validation questions: {len(val_questions)}")
val_corpus: Dict[str, str] = val_dataset.get_corpus()
logger.info(f"Number of documents in validation corpus: {len(val_corpus)}")
val_relevant_contexts = val_dataset.get_relevant_contexts()
val_sample_question_id, val_sample_question = next(iter(val_questions.items()))
logger.info(f"Sample question: {val_sample_question_id}")
val_sample_relevant_doc_id = val_relevant_contexts[val_sample_question_id][0]
logger.info(f"Relevant document ID: {val_sample_relevant_doc_id}")
val_sample_doc_content = val_corpus[val_sample_relevant_doc_id]
logger.info(f"Sample document content: {val_sample_doc_content[:100]}...")

2024-09-19 01:25:27.099 | INFO     | __main__:<module>:2 - Number of validation questions: 308
2024-09-19 01:25:27.100 | INFO     | __main__:<module>:4 - Number of documents in validation corpus: 154
2024-09-19 01:25:27.101 | INFO     | __main__:<module>:7 - Sample question: fbfa14ee-e783-4549-904e-f2c48faecaa6
2024-09-19 01:25:27.102 | INFO     | __main__:<module>:9 - Relevant document ID: a2528cf7-9053-495e-be81-dbe64476189d
2024-09-19 01:25:27.103 | INFO     | __main__:<module>:11 - Sample document content: to facilitate the provider’s obligation to comply with the requirements of this Regulation, when the...


In [33]:
test_dataset:QADataset = test_generator.generate_dataset(n_questions=2,max_workers=max_workers)


Generating questions:   0%|          | 0/155 [00:00<?, ?it/s]

2024-09-19 01:25:46.778 | INFO     | __main__:generate_dataset:47 - Dataset generation complete.


In [34]:
# test_questions, test_relevant_contexts = ### YOUR CODE HERE
# test_questions, test_relevant_contexts = create_questions(test_split_documents, n_questions=2)
test_questions: Dict[str, str] = test_dataset.get_questions()
logger.info(f"Number of test questions: {len(test_questions)}")
test_corpus: Dict[str, str] = test_dataset.get_corpus()
logger.info(f"Number of documents in test corpus: {len(test_corpus)}")
test_relevant_contexts = test_dataset.get_relevant_contexts()
test_sample_question_id, test_sample_question = next(iter(test_questions.items()))
logger.info(f"Sample question: {test_sample_question_id}")
test_sample_relevant_doc_id = test_relevant_contexts[test_sample_question_id][0]
logger.info(f"Relevant document ID: {test_sample_relevant_doc_id}")
test_sample_doc_content = test_corpus[test_sample_relevant_doc_id]
logger.info(f"Sample document content: {test_sample_doc_content[:100]}...")

2024-09-19 01:25:47.001 | INFO     | __main__:<module>:4 - Number of test questions: 310
2024-09-19 01:25:47.002 | INFO     | __main__:<module>:6 - Number of documents in test corpus: 155
2024-09-19 01:25:47.004 | INFO     | __main__:<module>:9 - Sample question: ee57e704-11a2-494e-a23e-529d91141a5b
2024-09-19 01:25:47.004 | INFO     | __main__:<module>:11 - Relevant document ID: 0647e8da-06dc-418d-a5db-7d51468ae3de
2024-09-19 01:25:47.005 | INFO     | __main__:<module>:13 - Sample document content: (109) Compliance with the obligations applicable to the providers of general-purpose AI models shoul...


### Reformatting and Saving Datasets

Now, we can save our datasets for later use!

> NOTE: If you ran into issues creating the data, you can use the data from the DataRepository. It's simply called: `train_dataset.jsonl`, etc.
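
As a rough sketch of what a saved file contains, the snippet below writes the three dictionaries to a single JSON document and reads them back using only the standard library. The toy IDs and texts are made up for illustration, and the temporary path stands in for the real dataset path. Note that despite the `.jsonl` extension, the file written by `serialize` is one JSON object rather than line-delimited JSON.

```python
import json
import tempfile
from pathlib import Path

# Toy data in the same three-part layout used by QADataset (values are illustrative).
toy_dataset = {
    "questions": {"q1": "Who drives the bus?"},
    "relevant_contexts": {"q1": ["d1"]},
    "corpus": {"d1": "The bus was driven by Kyle, the Bus Driver."},
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "toy_dataset.jsonl"
    # Same json.dump settings as QADataset.serialize.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(toy_dataset, f, indent=4, ensure_ascii=False)
    with open(path, "r", encoding="utf-8") as f:
        reloaded = json.load(f)

# Every question must point at documents that exist in the corpus.
for question_id, doc_ids in reloaded["relevant_contexts"].items():
    assert question_id in reloaded["questions"]
    assert all(doc_id in reloaded["corpus"] for doc_id in doc_ids)

print(reloaded == toy_dataset)
```

The integrity loop mirrors the checks in `QADataset.validate`, which is a cheap way to confirm that a file loaded from elsewhere has the expected shape.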

In [35]:
import json

# training_corpus = {train_item.metadata["id"] : train_item.page_content for train_item in training_split_documents}

# train_dataset = {
#     "questions" : training_questions,
#     "relevant_contexts" : training_relevant_contexts,
#     "corpus" : training_corpus
# }
training_dataset_path: Path = repo_root_path.joinpath("training_dataset.jsonl")
# train_generator.save_dataset_to_json(str(training_dataset_path))
# with open(f"{str(training_dataset_path)}", "w") as f:
#     json.dump(
#         train_dataset, f, indent=4, default=lambda x: x.__dict__, ensure_ascii=False
#     )
train_dataset.serialize(str(training_dataset_path))
logger.info(f"Training dataset saved to {training_dataset_path}")

2024-09-19 01:25:47.234 | INFO     | __main__:serialize:48 - Dataset serialized to /notebooks/DataRepository/training_dataset.jsonl
2024-09-19 01:25:47.235 | INFO     | __main__:<module>:17 - Training dataset saved to /notebooks/DataRepository/training_dataset.jsonl


In [36]:
# val_corpus = {val_item.metadata["id"] : val_item.page_content for val_item in val_split_documents}

# val_dataset = {
#     "questions" : val_questions,
#     "relevant_contexts" : val_relevant_contexts,
#     "corpus" : val_corpus
# }
# with open("val_dataset.jsonl", "w") as f:
#   json.dump(val_dataset, f)

val_dataset_path: Path = repo_root_path.joinpath("val_dataset.jsonl")
# with open(f"{str(val_dataset_path)}", "w") as f:
#     json.dump(
#         val_dataset, f, indent=4, default=lambda x: x.__dict__, ensure_ascii=False
#     )
# val_generator.save_dataset_to_json(str(val_dataset_path))
val_dataset.serialize(str(val_dataset_path))
logger.info(f"Validation dataset saved to {val_dataset_path}")

2024-09-19 01:25:47.402 | INFO     | __main__:serialize:48 - Dataset serialized to /notebooks/DataRepository/val_dataset.jsonl
2024-09-19 01:25:47.403 | INFO     | __main__:<module>:18 - Validation dataset saved to /notebooks/DataRepository/val_dataset.jsonl


In [37]:
# train_corpus = {test_item.metadata["id"] : test_item.page_content for test_item in test_split_documents}

# test_dataset = {
#     "questions" : test_questions,
#     "relevant_contexts" : test_relevant_contexts,
#     "corpus" : train_corpus
# }
# with open("test_dataset.jsonl", "w") as f:
#   json.dump(test_dataset, f)
test_dataset_path: Path = repo_root_path.joinpath("test_dataset.jsonl")

# with open(f"{str(test_dataset_path)}", "w") as f:
#     json.dump(
#         test_dataset, f, indent=4, default=lambda x: x.__dict__, ensure_ascii=False
#     )
# test_generator.save_dataset_to_json(str(test_dataset_path))
test_dataset.serialize(str(test_dataset_path))
logger.info(f"Test dataset saved to {test_dataset_path}")

2024-09-19 01:25:47.523 | INFO     | __main__:serialize:48 - Dataset serialized to /notebooks/DataRepository/test_dataset.jsonl
2024-09-19 01:25:47.524 | INFO     | __main__:<module>:18 - Test dataset saved to /notebooks/DataRepository/test_dataset.jsonl


## Task 4: Fine-tuning `snowflake-arctic-embed-m`

Now that we have a dataset, let's grab a `sentence-transformers` Embeddings model!

We'll be using Snowflake's [`snowflake-arctic-embed-m`](https://huggingface.co/Snowflake/snowflake-arctic-embed-m) as a base embeddings model.

It is a well-performing embeddings model by itself, but there are a lot of very specific domain terms and vocabulary in our corpus, so let's fine-tune it and see what that can do for us!

<!-- We'll grab some necessary imports from `sentence_transformers` and `torch`. -->

> NOTE: PyTorch (`torch`) is a popular machine learning library. While we don't go very deep into PyTorch, it's an incredibly powerful and interesting library! Please read more about it [here](https://pytorch.org/tutorials/beginner/basics/intro.html)!

In [38]:
import json
import requests
import re
from typing import Dict, List, Optional
from loguru import logger
from packaging import version
from sentence_transformers import SentenceTransformer
from torch.torch_version import TorchVersion

class SentenceTransformerFactory:
    """
    Factory class for creating and validating SentenceTransformer models.
    """

    def __init__(
        self, model_name: str, validation_sentences: Optional[List[str]] = None
    ) -> None:
        self.model_name: str = model_name
        self.validation_sentences: Optional[List[str]] = validation_sentences
        self.version_constraints: Dict[str, str] = self._get_version_constraints()
        logger.info(f"Initialized SentenceTransformerFactory with model: {model_name}")

    @logger.catch(reraise=True)
    def create(self) -> SentenceTransformer:
        """
        Create and validate the SentenceTransformer model.
        """
        self._run_all_sanity_checks()
        model = SentenceTransformer(self.model_name)
        logger.info(f"Loaded SentenceTransformer model: {self.model_name}")
        if self.validation_sentences:
            self._validate_model(model)
        return model
    
    @logger.catch(reraise=True)
    def _get_version_constraints(self) -> Dict[str, str]:
        """
        Fetch model metadata from Hugging Face and extract version constraints.
        """
        logger.info(f"Fetching version constraints for model: {self.model_name}")
        base_url: str = f"https://huggingface.co/{self.model_name}/raw/main/"
        files_to_check: List[str] = [
            "modules.json",
            "config.json",
            "config_sentence_transformers.json",
        ]
        version_constraints = {}
        for file_name in files_to_check:
            url = base_url + file_name
            try:
                logger.debug(f"Fetching metadata from: {url}")
                response = requests.get(url)
                response.raise_for_status()
                data = json.loads(response.text)
                logger.debug(f"Successfully fetched metadata from: {url}")

                # Merge version constraints if present
                if "__version__" in data:
                    version_info = data["__version__"]
                    version_constraints.update(version_info)
                    logger.debug(f"Extracted version constraints: {version_info}")
            except requests.HTTPError as e:
                logger.warning(f"Failed to fetch {file_name} from Hugging Face: {e}")
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON from {file_name}: {e}")
        if not version_constraints:
            logger.warning("No version constraints found in model metadata.")
        return version_constraints

    @logger.catch(reraise=True)
    def _run_all_sanity_checks(self) -> None:
        """
        Run sanity checks for required packages and versions.
        """
        self._sanity_check_numpy()
        self._sanity_check_cuda()
        self._sanity_check_package("torch")
        self._sanity_check_package("transformers")
        self._sanity_check_package("sentence_transformers")
        logger.info("All sanity checks passed.")


    @logger.catch(reraise=True)
    def _sanity_check_package(self, package_name: str) -> None:
        """
        Check if the installed version of a package meets the required version.
        """
        required_version = self.version_constraints.get(package_name)
        if not required_version:
            logger.info(f"No version constraint specified for {package_name}. Skipping check.")
            return
        try:
            module = __import__(package_name)
        except ImportError as e:
            raise ImportError(f"Failed to import {package_name}: {e}") from e
        installed_version = module.__version__
        if not self._compare_versions(installed_version, required_version):
            raise ImportError(
                f"{package_name} version must be >= {required_version}, but found {installed_version}"
            )
        logger.info(f"{package_name} version {installed_version} meets the requirement.")


    @staticmethod
    def _compare_versions(installed_version: str, required_version: str) -> bool:
        """
        Compare two version strings.
        """
        installed_ver = version.parse(installed_version)
        required_ver = version.parse(required_version)
        return installed_ver >= required_ver


    @staticmethod
    @logger.catch(reraise=True)
    def _sanity_check_numpy() -> None:
        import numpy as np
        required_version = "1.16.0"
        installed_version: str = np.__version__
        if version.parse(installed_version) < version.parse(required_version):
            raise ImportError(f"NumPy version must be >= {required_version}, found {installed_version}")
        logger.info(f"NumPy version {installed_version} is adequate.")

    @staticmethod
    @logger.catch(reraise=True)
    def _sanity_check_cuda() -> None:
        import torch
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. Please ensure CUDA is installed and properly configured.")
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
    

    @logger.catch(reraise=True)
    def _validate_model(self, model: SentenceTransformer) -> None:
        """
        Validate the model by encoding validation sentences.
        """
        if self.validation_sentences is None:
            logger.info("skipping model encoding validation")
            return
        embeddings = model.encode(self.validation_sentences)
        if embeddings is None or len(embeddings) != len(self.validation_sentences):
            raise ValueError("Model validation failed: Embeddings not generated correctly.")
        logger.info("Model validation successful.")

In [39]:
model_id: str = "Snowflake/snowflake-arctic-embed-m"
validation_sentences: List[str] = [
    "This is a custom validation sentence",
    "Another custom sentence for validation"
]
factory: SentenceTransformerFactory = SentenceTransformerFactory(model_id, validation_sentences)
model: SentenceTransformer = factory.create()

2024-09-19 01:25:47.816 | INFO     | __main__:_get_version_constraints:40 - Fetching version constraints for model: Snowflake/snowflake-arctic-embed-m
2024-09-19 01:25:47.817 | DEBUG    | __main__:_get_version_constraints:51 - Fetching metadata from: https://huggingface.co/Snowflake/snowflake-arctic-embed-m/raw/main/modules.json
2024-09-19 01:25:47.869 | DEBUG    | __main__:_get_version_constraints:55 - Successfully fetched metadata from: https://huggingface.co/Snowflake/snowflake-arctic-embed-m/raw/main/modules.json
2024-09-19 01:25:47.870 | DEBUG    | __main__:_get_version_constraints:51 - Fetching metadata from: https://huggingface.co/Snowflake/snowflake-arctic-embed-m/raw/main/config.json
2024-09-19 01:25:47.937 | DEBUG    | 

We're using a toy batch size here to reflect the limited number of examples we have.

> NOTE: It is typical to use a much larger batch size (~64+), hardware permitting.
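
For a sense of scale, here is a small sketch of what a batch size of 20 implies for the 1,440 training pairs generated above, compared with a batch size of 64. It also counts in-batch negatives, assuming the usual behavior of `MultipleNegativesRankingLoss`, where the documents of the other pairs in a batch serve as negatives for a given question. The helper name `batch_stats` is hypothetical.

```python
import math


def batch_stats(num_examples: int, batch_size: int) -> dict:
    """Summarize steps per epoch and in-batch negatives for a given batch size."""
    return {
        "steps_per_epoch": math.ceil(num_examples / batch_size),
        # Each question is contrasted with the documents of the other pairs in its batch.
        "negatives_per_question": batch_size - 1,
    }


for size in (20, 64):
    print(size, batch_stats(num_examples=1440, batch_size=size))
```

Larger batches give each question more negatives to be ranked against, which is one reason bigger batch sizes are usually preferred for this kind of contrastive training.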

In [40]:
# BATCH_SIZE = 20

Let's move our dataset into the expected format for training.

In [41]:
# from sentence_transformers import InputExample

# corpus = train_generator.get_corpus()
# queries = train_dataset['questions']
# relevant_docs = train_dataset['relevant_contexts']




# examples: List[InputExample] = []
# for query_id, query in queries.items():
#     doc_id = relevant_docs[query_id][0]
#     text = corpus[doc_id]
#     example = InputExample(texts=[query, text])
#     examples.append(example)

Now we can create a `torch` `DataLoader`!

In [42]:
# from torch.utils.data import DataLoader
# # FIXME: the type is wrong
# # Argument of type "List[InputExample]" cannot be assigned to parameter "dataset" of type "Dataset[T_co@DataLoader]" in function "__init__"
# #   "List[InputExample]" is incompatible with "Dataset[T_co@DataLoader]"PylancereportArgumentType
# # (variable) examples: List[InputExample]
# loader = DataLoader(
#     examples, batch_size=BATCH_SIZE
# )

Next up, we'll prepare our loss function!

Loss is an important part of training, fine-tuning, and more. If you want a deep dive on loss - you can check out our [event on loss!](https://www.youtube.com/watch?v=iB8FWR9aD5Q&t=8s).

The core loss we're using today is called `MultipleNegativesRankingLoss` - you can find more information [here](https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/losses/MultipleNegativesRankingLoss.py).

This is "wrapped" in `MatryoshkaLoss`, which you can read the implementation of [here](https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/losses/MatryoshkaLoss.py).

In [43]:
# from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

# matryoshka_dimensions = [768, 512, 256, 128, 64]
# inner_train_loss = MultipleNegativesRankingLoss(model)
# train_loss = MatryoshkaLoss(
#     model, inner_train_loss, matryoshka_dims=matryoshka_dimensions
# )

##### 🏗️ Activity #2:

Both of these losses sound "cool", but what are they - exactly - under the hood?

What are these losses specifically doing? Please write a short summary of each loss.

> NOTE: This is a course focused on AI Engineering and the application of AI - looking for a hint? Try pasting the code (linked above) into ChatGPT/Claude to write the summary!


###### 1. `MultipleNegativesRankingLoss`:

This loss function is crucial for training sentence embeddings, especially for retrieval tasks. It uses contrastive learning to teach the model how to distinguish between similar and dissimilar pairs of sentences.

How it works (with a detailed example):

Let's say we have a batch of 3 sentence pairs:
- (Q1, D1): "What is AI?" and "Artificial Intelligence is a field of computer science..."
- (Q2, D2): "How does photosynthesis work?" and "Photosynthesis is a process used by plants..."
- (Q3, D3): "Who wrote Romeo and Juliet?" and "Romeo and Juliet was written by William Shakespeare..."

Step 1: Embedding the sentences
First, the model converts each sentence into a numerical vector (embedding). Let's simplify and say we're using 3-dimensional embeddings for this example:

- Q1: [0.5, 0.3, 0.8]
- D1: [0.6, 0.2, 0.7]
- Q2: [-0.2, 0.9, 0.1]
- D2: [-0.3, 0.8, 0.2]
- Q3: [0.1, -0.5, 0.7]
- D3: [0.2, -0.4, 0.6]

Step 2: Computing similarities
The model scores Q1 against every *document* in the batch. D1 is the positive, and the documents paired with the other queries (D2 and D3) act as in-batch negatives; the other queries (Q2 and Q3) are not candidates. `MultipleNegativesRankingLoss` uses cosine similarity multiplied by a scale factor (20 by default), but to keep the arithmetic simple we use the raw dot product here:

- sim(Q1, D1) = 0.5*0.6 + 0.3*0.2 + 0.8*0.7 = 0.92 (positive pair)
- sim(Q1, D2) = 0.5*(-0.3) + 0.3*0.8 + 0.8*0.2 = 0.25 (negative)
- sim(Q1, D3) = 0.5*0.2 + 0.3*(-0.4) + 0.8*0.6 = 0.46 (negative)

Step 3: Applying softmax
These similarities are then passed through a softmax function to convert them into probabilities:

```python
softmax([0.92, 0.25, 0.46]) ≈ [0.47, 0.24, 0.29]
```
Step 4: Computing the loss
The loss is computed as the negative log of the probability assigned to the positive pair:

```python
loss = -log(0.47) ≈ 0.76
```

This process is repeated for each query in the batch (Q2 and Q3), and the losses are averaged.

How this encourages the desired behavior:
1. To minimize this loss, the model needs to increase the similarity of Q1 and D1 relative to the other pairs.
2. If it does this successfully, the probability assigned to (Q1, D1) will increase, and the loss will decrease.
3. Simultaneously, this process pushes the embeddings of dissimilar sentences further apart in the vector space.

Over many training iterations, this leads to a model that produces embeddings where similar sentences are closer together and dissimilar sentences are further apart.
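
The walkthrough above can be reproduced in a few lines of plain Python. This is a minimal sketch, not the library's implementation: `dot` and `in_batch_loss` are hypothetical helper names, each query is scored only against the documents in the batch, and the raw dot product (with an optional `scale`) stands in for the scaled cosine similarity that `MultipleNegativesRankingLoss` uses by default.

```python
import math

# Toy 3-dimensional embeddings from the example above.
queries = {
    "Q1": [0.5, 0.3, 0.8],
    "Q2": [-0.2, 0.9, 0.1],
    "Q3": [0.1, -0.5, 0.7],
}
docs = {
    "D1": [0.6, 0.2, 0.7],
    "D2": [-0.3, 0.8, 0.2],
    "D3": [0.2, -0.4, 0.6],
}


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def in_batch_loss(query_vecs, doc_vecs, scale=1.0):
    """Mean cross-entropy where document i is the positive for query i.

    Every other document in the batch acts as a negative for that query.
    """
    losses = []
    for i, q in enumerate(query_vecs):
        scores = [scale * dot(q, d) for d in doc_vecs]
        log_sum_exp = math.log(sum(math.exp(s) for s in scores))
        # log_sum_exp - scores[i] equals -log(softmax(scores)[i])
        losses.append(log_sum_exp - scores[i])
    return sum(losses) / len(losses)


q_vecs = list(queries.values())
d_vecs = list(docs.values())

q1_loss = in_batch_loss(q_vecs[:1], d_vecs)  # loss for Q1 only
batch_loss = in_batch_loss(q_vecs, d_vecs)   # averaged over Q1, Q2, Q3
print(f"Q1 loss: {q1_loss:.2f}")
print(f"Batch loss: {batch_loss:.2f}")
```

Raising `scale` sharpens the softmax, so a query that already ranks its own document first is penalized less, while a misranked one is penalized more.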

###### 2. `MatryoshkaLoss`:

MatryoshkaLoss extends this concept to multiple embedding sizes simultaneously. Here's a more detailed explanation:

How it works (with an example):

Let's say we're using dimensions [`768`, `512`, `256`, `128`, `64`] and we have a sentence: "The quick brown fox jumps over the lazy dog."

Step 1: Generating embeddings
The model produces a 768-dimensional embedding for this sentence. Let's simplify
and say the first 10 dimensions look like this:

```python
[0.1, -0.3, 0.5, 0.2, -0.1, 0.4, 0.6, -0.2, 0.3, 0.7, ...]
```

Step 2: Creating slices
MatryoshkaLoss creates slices of this embedding:
- 768-dim: [0.1, -0.3, 0.5, 0.2, -0.1, 0.4, 0.6, -0.2, 0.3, 0.7, ...] (all 768 dimensions)
- 512-dim: [0.1, -0.3, 0.5, 0.2, -0.1, 0.4, 0.6, -0.2, 0.3, 0.7, ...] (first 512 dimensions)
- 256-dim: [0.1, -0.3, 0.5, 0.2, -0.1, 0.4, 0.6, -0.2, 0.3, 0.7, ...] (first 256 dimensions)
- 128-dim: [0.1, -0.3, 0.5, 0.2, -0.1, 0.4, 0.6, -0.2, 0.3, 0.7, ...] (first 128 dimensions)
- 64-dim: [0.1, -0.3, 0.5, 0.2, -0.1, 0.4, 0.6, -0.2, 0.3, 0.7, ...] (first 64 dimensions)

Step 3: Applying `MultipleNegativesRankingLoss`
For each slice, the `MultipleNegativesRankingLoss` is computed as described earlier.

Step 4: Combining losses

The losses from each slice are combined as a weighted sum. By default every slice has a weight of 1, so all sizes count equally:

```python
total_loss = loss_768 + loss_512 + loss_256 + loss_128 + loss_64
```
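
The four steps can be sketched in plain Python as follows. This is an illustration under stated assumptions rather than the library's code: the helper names (`truncate_and_normalize`, `in_batch_loss`, `matryoshka_loss`) are hypothetical, the toy embeddings have only 4 dimensions, and the slices are combined as a weighted sum with weights of 1 (which differs from an average only by a constant factor).

```python
import math


def truncate_and_normalize(vec, dim):
    """Keep the first `dim` values and rescale them to unit length."""
    head = vec[:dim]
    norm = math.sqrt(sum(x * x for x in head)) or 1.0
    return [x / norm for x in head]


def in_batch_loss(query_vecs, doc_vecs, scale=20.0):
    """Ranking loss where document i is the positive for query i."""
    losses = []
    for i, q in enumerate(query_vecs):
        scores = [scale * sum(a * b for a, b in zip(q, d)) for d in doc_vecs]
        log_sum_exp = math.log(sum(math.exp(s) for s in scores))
        losses.append(log_sum_exp - scores[i])
    return sum(losses) / len(losses)


def matryoshka_loss(query_vecs, doc_vecs, dims, weights=None):
    """Apply the inner loss to each truncated slice and sum the results."""
    weights = weights or [1.0] * len(dims)
    total = 0.0
    for dim, weight in zip(dims, weights):
        q = [truncate_and_normalize(v, dim) for v in query_vecs]
        d = [truncate_and_normalize(v, dim) for v in doc_vecs]
        total += weight * in_batch_loss(q, d)
    return total


# Hypothetical 4-dimensional embeddings for two (query, document) pairs.
toy_queries = [[0.9, 0.1, 0.3, -0.2], [-0.1, 0.8, -0.4, 0.5]]
toy_docs = [[0.8, 0.2, 0.1, -0.3], [0.0, 0.9, -0.5, 0.4]]

total_loss = matryoshka_loss(toy_queries, toy_docs, dims=[4, 2])
print(f"Total loss over slices [4, 2]: {total_loss:.4f}")
```

Because every slice starts at index 0, the model is pushed to pack the most useful information into the leading dimensions, which is what makes truncating the embedding at inference time viable.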

Significance of `matryoshka_dimensions` [`768`, `512`, `256`, `128`, `64`]:

1. 768: This is the base dimension for many BERT-style models. It provides the highest information capacity.
2. 512, 256, 128, 64: After the first step down, each dimension is half of the previous one. This geometric spacing provides a good spread of sizes.
3. Range: From 768 to 64, it covers use cases from high-resource environments to very constrained ones.
4. Alignment: Every size is a multiple of 64 (and all but 768 are powers of 2), which tends to be computationally efficient in many systems.

You could choose different numbers based on your specific needs. For example:
- [1024, 512, 256, 128] for larger base embeddings
- [512, 384, 256, 192, 128] for more granularity in the middle range

The combination of MultipleNegativesRankingLoss and MatryoshkaLoss allows the fine-tuning process to:
1. Improve embedding quality for retrieval tasks across multiple dimensions.
2. Create a single, flexible model adaptable to various computational environments.
3. Potentially enhance generalization by maintaining information across different embedding sizes.

This approach is particularly valuable when you need a versatile model that can be deployed in various settings or used for different downstream tasks with varying computational constraints.

By training with these losses, you're essentially creating a "Swiss Army knife" of an embedding model – one tool that can adapt to many different situations and requirements.

Now we can set up our evaluator.

> NOTE: Due to the formatting of our dataset - this is all we have to do!
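
To make "the formatting of our dataset" concrete, here is a small sketch of the three mappings the evaluator takes: query ID to query text, document ID to document text, and query ID to the set of relevant document IDs. The IDs, the second document, and the helper `find_dangling_ids` are hypothetical and exist only to illustrate the shapes.

```python
from typing import Dict, List, Set

# Hypothetical IDs and texts, used only to show the expected shapes.
queries: Dict[str, str] = {"q1": "Who drives the bus?"}
corpus: Dict[str, str] = {
    "d1": "The bus was driven by Kyle, the Bus Driver.",
    "d2": "The train departs at noon.",
}
relevant_contexts: Dict[str, List[str]] = {"q1": ["d1"]}

# The evaluator expects a set of document IDs per query.
relevant_docs: Dict[str, Set[str]] = {
    qid: set(doc_ids) for qid, doc_ids in relevant_contexts.items()
}


def find_dangling_ids(queries, corpus, relevant_docs) -> List[str]:
    """Return IDs in relevant_docs that are missing from queries or corpus."""
    missing = [qid for qid in relevant_docs if qid not in queries]
    for doc_ids in relevant_docs.values():
        missing.extend(doc_id for doc_id in doc_ids if doc_id not in corpus)
    return missing


print(find_dangling_ids(queries, corpus, relevant_docs))
```

A check like this is cheap to run before training, since a relevant document ID that is missing from the corpus can never be retrieved and would silently drag the metrics down.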

In [44]:
# from sentence_transformers.evaluation import InformationRetrievalEvaluator

# corpus = val_dataset['corpus']
# queries = val_dataset['questions']
# relevant_docs = val_dataset['relevant_contexts']

# evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)

We'll train this model for 5 epochs, though you could increase this number if you had significantly more data.

In [45]:
# EPOCHS = 5

It's training time!

> NOTE: We're manually defining a warm-up period here - this is just to provide a smooth ramp into our training!
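
As a rough sketch of what the warm-up amounts to, the snippet below computes the number of warm-up steps for the 1,440 training examples reported in the training log and shows a linear ramp of the learning-rate multiplier. The helper names are hypothetical, and the ramp is only an illustration of linear warm-up: it does not model what the scheduler does after the warm-up ends.

```python
import math


def warmup_steps_for(num_examples, batch_size, epochs, warmup_fraction=0.1):
    """Number of optimizer steps spent warming up."""
    # Batches per epoch when the last partial batch is kept.
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return int(steps_per_epoch * epochs * warmup_fraction)


def linear_warmup_multiplier(step, warmup_steps):
    """Fraction of the target learning rate applied at a given step."""
    if warmup_steps == 0:
        return 1.0
    return min(1.0, step / warmup_steps)


# 1440 examples / batch size 20 = 72 steps per epoch; 72 * 5 epochs = 360 steps.
warmup_steps = warmup_steps_for(num_examples=1440, batch_size=20, epochs=5)
print(f"Warm-up steps: {warmup_steps}")

for step in (0, 9, 18, 36, 100):
    print(step, linear_warmup_multiplier(step, warmup_steps))
```

With these settings the first 10% of training uses a reduced learning rate, which helps avoid large, destabilizing updates to the pretrained weights at the very start.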

In [46]:
# import transformers
# import accelerate

# print(transformers.__version__)
# print(accelerate.__version__)

In [47]:
# warmup_steps = int(len(loader) * EPOCHS * 0.1)

# model.fit(
#     train_objectives=[(loader, train_loss)],
#     epochs=EPOCHS,
#     warmup_steps=warmup_steps,
#     output_path='finetuned_arctic',
#     show_progress_bar=True,
#     evaluator=evaluator,
#     evaluation_steps=50,
# )

In [48]:
# from typing import Dict, List, Optional, Union,Iterator
# from sentence_transformers import SentenceTransformer, InputExample
# from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
# from sentence_transformers.evaluation import InformationRetrievalEvaluator
# # TODO: having two datasets is dumb
# from torch.utils.data import DataLoader, Dataset as TorchDataset
# from datasets import Dataset as HuggingFaceDataset
# from loguru import logger

# class InputExampleDataset(TorchDataset):
#     def __init__(self, examples):
#         self.examples = [InputExample(texts=[ex['query'], ex['context']]) if isinstance(ex, dict) else ex for ex in examples]

#     def __len__(self):
#         return len(self.examples)

#     def __getitem__(self, idx):
#         return self.examples[idx]

#     def __iter__(self):
#         return iter(self.examples)
# class EmbeddingFinetuner:
#     def __init__(
#         self,
#         model: SentenceTransformer,
#         train_dataset: HuggingFaceDataset,
#         val_dataset: HuggingFaceDataset,
#         matryoshka_dims: List[int] = [768, 512, 256, 128, 64],
#         batch_size: int = 20,
#         epochs: int = 5
#     ):
#         self.model = model
#         self.matryoshka_dims = matryoshka_dims
#         self.batch_size = batch_size
#         self.epochs = epochs
#         self.train_dataloader: Optional[DataLoader] = None
#         self.train_loss: Optional[MatryoshkaLoss] = None
#         self.evaluator: Optional[InformationRetrievalEvaluator] = None

#         self.train_dataset = train_dataset
#         self.val_dataset = val_dataset

#         logger.info(f"EmbeddingFinetuner initialized with model: {model.__class__.__name__}")
#     def _process_dataset(self, dataset: HuggingFaceDataset) -> InputExampleDataset:
#         logger.info(f"Processing dataset with {len(dataset)} examples")
#         processed_examples = []
#         for example in dataset:
#             processed_example = self._create_input_example(example)
#             if processed_example is not None:
#                 processed_examples.append(processed_example)
        
#         logger.info(f"Processed dataset now has {len(processed_examples)} examples")
#         return InputExampleDataset(processed_examples)
#     @staticmethod
#     def _create_input_example(example: Union[Dict[str, Union[str, List[str]]], str]) -> Optional[InputExample]:
#         if isinstance(example, dict):
#             if 'doc_id' in example:
#                 # Case when doc_id exists (unprocessed input)
#                 query = example.get('questions')
#                 relevant_contexts = example.get('relevant_contexts')
#                 corpus = example.get('corpus')

#                 if not isinstance(query, str) or not isinstance(relevant_contexts, list) or not isinstance(corpus, dict):
#                     logger.error(f"Unexpected types: query {type(query)}, relevant_contexts {type(relevant_contexts)}, corpus {type(corpus)}. Expected str, list, dict respectively.")
#                     return None

#                 if not relevant_contexts:
#                     logger.error("relevant_contexts is empty.")
#                     return None

#                 doc_id = relevant_contexts[0]
#                 text = corpus.get(doc_id, "")
#                 return InputExample(texts=[query, text])
#             else:
#                 # Case for already processed input
#                 query = example.get('query')
#                 context = example.get('context')
                
#                 if not isinstance(query, str) or not isinstance(context, str):
#                     logger.error(f"Unexpected types: query {type(query)}, context {type(context)}. Expected str for both.")
#                     return None
                
#                 return InputExample(texts=[query, context])
#         elif isinstance(example, str):
#             # If the example is a string, assume it's the query and there's no context
#             logger.warning("Received a string example. Treating it as a query with no context.")
#             return InputExample(texts=[example, ""])
#         else:
#             logger.error(f"Unexpected example type: {type(example)}. Expected dict or str.")
#             return None
#     def _prepare_training_data(self) -> None:
#         logger.info("Preparing training data")
        
#         corpus = self.train_dataset['corpus']
#         queries = self.train_dataset['questions']
#         relevant_docs = self.train_dataset['relevant_contexts']
        
#         processed_examples = []
#         for query_id, query in queries.items():
#             doc_id = relevant_docs[query_id][0]  # Assuming we're using the first relevant doc
#             text = corpus[doc_id]
#             processed_examples.append(InputExample(texts=[query, text]))
        
#         dataset = InputExampleDataset(processed_examples)
#         self.train_dataloader = DataLoader(
#             dataset, batch_size=self.batch_size, shuffle=True
#         )
        
#         logger.debug(f"Training data prepared. Total examples: {len(processed_examples)}")
#     def _prepare_evaluator(self) -> None:
#         logger.info("Preparing evaluator")
        
#         processed_dataset = self._process_dataset(self.val_dataset)
#         queries = {}
#         corpus = {}
        
#         for i, example in enumerate(processed_dataset):
#             if isinstance(example, InputExample):
#                 queries[str(i)] = example.texts[0]
#                 corpus[str(i)] = example.texts[1]
#             elif isinstance(example, dict):
#                 queries[str(i)] = example['query']
#                 corpus[str(i)] = example['context']
#             else:
#                 logger.warning(f"Unexpected type in processed_dataset: {type(example)}. Skipping.")
        
#         relevant_docs = {str(i): {str(i)} for i in range(len(queries))}
        
#         self.evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)
#         logger.debug(f"Created InformationRetrievalEvaluator with {len(queries)} queries and {len(corpus)} documents")
#     def _prepare_loss_function(self) -> None:
#         logger.info("Preparing loss function")
#         inner_train_loss = MultipleNegativesRankingLoss(self.model)
#         self.train_loss = MatryoshkaLoss(self.model, inner_train_loss, matryoshka_dims=self.matryoshka_dims)
#         logger.debug(f"Created MatryoshkaLoss with dimensions: {self.matryoshka_dims}")

#     def train(self, output_path: str) -> None:
#         logger.info(f"Starting training process. Output path: {output_path}")
#         self._prepare_training_data()
#         self._prepare_loss_function()
#         self._prepare_evaluator()

#         if not self.train_dataloader or not self.train_loss or not self.evaluator:
#             logger.error("Training components not properly initialized")
#             raise ValueError("Training components not properly initialized")

#         warmup_steps = int(len(self.train_dataloader) * self.epochs * 0.1)
#         logger.debug(f"Warmup steps: {warmup_steps}")

#         self.model.fit(
#             train_objectives=[(self.train_dataloader, self.train_loss)],
#             epochs=self.epochs,
#             warmup_steps=warmup_steps,
#             output_path=output_path,
#             show_progress_bar=True,
#             evaluator=self.evaluator,
#             evaluation_steps=50,
#         )
#         logger.info("Training completed")

In [49]:
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
# Dataset is needed for the isinstance check in _prepare_training_data.
from torch.utils.data import DataLoader, Dataset
from loguru import logger

class EmbeddingFinetuner:
    """
    Fine-tunes an embedding model with the provided training data.
    """

    def __init__(
        self,
        model: SentenceTransformer,
        train_dataset: QADataset,
        val_dataset: QADataset,
        matryoshka_dims: List[int] = [768, 512, 256, 128, 64],
        batch_size: int = 20,
        epochs: int = 5,
    ):
        self.model: SentenceTransformer = model
        self.matryoshka_dims: List[int] = matryoshka_dims
        self.batch_size: int = batch_size
        self.epochs: int = epochs
        self.train_dataloader: Optional[DataLoader] = None
        self.train_loss: Optional[MatryoshkaLoss] = None
        self.evaluator: Optional[InformationRetrievalEvaluator] = None
        self.train_dataset: QADataset = train_dataset
        self.val_dataset: QADataset = val_dataset

    @logger.catch(reraise=True)
    def _prepare_training_data(self) -> None:
        """
        Prepare the training data loader.
        """
        if not isinstance(self.train_dataset, Dataset):
            raise TypeError("train_dataset must be an instance of torch.utils.data.Dataset")
        self.train_dataloader = DataLoader(
            self.train_dataset, shuffle=True, batch_size=self.batch_size
        )
        logger.info(f"Prepared training data with {len(self.train_dataset)} examples.")

    @logger.catch(reraise=True)
    def _prepare_evaluator(self) -> None:
        """
        Prepare the evaluator using the validation dataset.
        """
        queries = self.val_dataset.get_questions()
        corpus = self.val_dataset.get_corpus()
        relevant_contexts = self.val_dataset.get_relevant_contexts()
        relevant_docs = {qid: set(doc_ids) for qid, doc_ids in relevant_contexts.items()}
        self.evaluator = InformationRetrievalEvaluator(
            queries=queries,
            corpus=corpus,
            relevant_docs=relevant_docs,
            show_progress_bar=True,
        )
        logger.info("Evaluator prepared.")

    @logger.catch(reraise=True)
    def _prepare_loss_function(self) -> None:
        """
        Prepare the loss function for training.
        """
        inner_train_loss = MultipleNegativesRankingLoss(self.model)
        self.train_loss = MatryoshkaLoss(
            self.model, inner_train_loss, matryoshka_dims=self.matryoshka_dims
        )
        logger.info("Loss function prepared.")
  
    @logger.catch(reraise=True)
    def train(self, output_path: str) -> None:
        """
        Train the embedding model.
        """
        self._prepare_training_data()
        self._prepare_loss_function()
        self._prepare_evaluator()
        if not self.train_dataloader or not self.train_loss or not self.evaluator:
            raise ValueError("Training components not properly initialized")
        warmup_steps = int(len(self.train_dataloader) * self.epochs * 0.1)
        self.model.fit(
            train_objectives=[(self.train_dataloader, self.train_loss)],
            epochs=self.epochs,
            warmup_steps=warmup_steps,
            evaluator=self.evaluator,
            evaluation_steps=50,
            output_path=output_path,
            show_progress_bar=True,
        )
        logger.info("Model training complete.")


In [50]:
finetuned_model_output_path: str = "finetuned_arctic"

In [51]:
finetuner = EmbeddingFinetuner(
  model=model,
  train_dataset=train_dataset,
  val_dataset=val_dataset,
)

finetuner.train(output_path=finetuned_model_output_path)
model: SentenceTransformer = finetuner.model

2024-09-19 01:25:54.182 | INFO     | __main__:_prepare_training_data:42 - Prepared training data with 1440 examples.
2024-09-19 01:25:54.184 | INFO     | __main__:_prepare_loss_function:70 - Loss function prepared.
2024-09-19 01:25:54.185 | INFO     | __main__:_prepare_evaluator:59 - Evaluator prepared.


2024-09-19 01:29:49.020 | INFO     | __main__:train:92 - Model training complete.


## Task 5: Evaluating our Retriever

Now that we have fine-tuned our retriever - let's see if it's worthwhile!


In [52]:
# import pandas as pd

# from langchain_community.vectorstores import FAISS
# from langchain_openai.embeddings import OpenAIEmbeddings
# from langchain_core.documents import Document

<!-- Now we'll define a function that will help us evaluate our retrieval process.

> NOTE: We're assuming 1 correct document in a "hit". -->

In [53]:
# from tqdm.notebook import tqdm_notebook
# def evaluate_openai(
#     dataset,
#     embed_model,
#     top_k=5,
#     verbose=False,
# ):
#   corpus = dataset['corpus']
#   questions = dataset['questions']
#   relevant_docs = dataset['relevant_contexts']
#   documents = [Document(page_content=content, metadata={"id": doc_id}) for doc_id, content in corpus.items()]
#   vectorstore = FAISS.from_documents(documents, embed_model)

#   retriever = vectorstore.as_retriever(search_kwargs={"k": top_k})

#   eval_results = []
#   for id, question in  tqdm_notebook(questions.items(), desc="Evaluating  retrieval"):
#     retrieved_nodes = retriever.invoke(question)
#     retrieved_ids = [node.metadata["id"] for node in retrieved_nodes]
#     expected_id = relevant_docs[id][0]
#     is_hit = expected_id in retrieved_ids
#     eval_results.append({"id": id, "question": question, "expected_id": expected_id, "is_hit": is_hit})

#   return eval_results

We start with an `Evaluator` class to help with checking results.

In [54]:
from dataclasses import dataclass
from typing import List, Union, Set, Dict, Optional, Tuple

from langchain_core.vectorstores.base import VectorStoreRetriever
import pandas as pd
from tqdm.notebook import tqdm_notebook

from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from loguru import logger
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings

@dataclass
class EvaluationResult:
    question_id: str
    question: str
    retrieved_ids: List[str]
    expected_ids: Set[str]
    is_hit: bool
    rank: Optional[int] = None  # Rank of the first relevant document


@dataclass
class EvaluationMetrics:
    model_name: str
    hit_rate: float
    mrr: float
    total_questions: int

class Evaluator:
    """
    Provides methods to evaluate embedding models on a given QADataset.
    Stores evaluation results in the class state.
    """

    def __init__(
        self,
        dataset: QADataset,
        embed_model: Union[OpenAIEmbeddings, HuggingFaceEmbeddings],
        top_k: int = 5,
    ):
        self.dataset = dataset
        self.embed_model: Union[OpenAIEmbeddings, HuggingFaceEmbeddings] = embed_model
        self.top_k = top_k
        self.retriever: Optional[VectorStoreRetriever] = None
        self.evaluation_results: List[EvaluationResult] = []
        self.hit_rate: Optional[float] = None
        self.mrr: Optional[float] = None

    @logger.catch(reraise=True)
    def evaluate(self) -> None:
        """
        Evaluate the embedding model on the dataset.
        Stores results in the class state.
        """
        self.prepare_retriever()
        self.evaluate_all_questions()
        self.compute_hit_rate()
        self.compute_mean_reciprocal_rank()
        logger.info(f"Model evaluation complete. Hit Rate: {self.hit_rate}, MRR: {self.mrr}")

    @logger.catch(reraise=True)
    def prepare_retriever(self) -> None:
        """
        Prepare the retriever from the dataset corpus and embedding model.
        """
        logger.info("Preparing vector store and retriever.")
        corpus = self.dataset.get_corpus()
        documents = [
            Document(page_content=content, metadata={"id": doc_id})
            for doc_id, content in corpus.items()
        ]
        vectorstore = FAISS.from_documents(documents, self.embed_model)
        self.retriever = vectorstore.as_retriever(search_kwargs={"k": self.top_k})
        logger.info("Retriever prepared.")

    @logger.catch(reraise=True)
    def evaluate_all_questions(self) -> None:
        """
        Evaluate all questions in the dataset using the retriever.
        Stores results in the class state.
        """
        logger.info("Starting evaluation of all questions.")
        questions = self.dataset.get_questions()
        relevant_contexts = self.dataset.get_relevant_contexts()
        self.evaluation_results = []

        for question_id, question in tqdm_notebook(questions.items(), desc="Evaluating model"):
            result = self.evaluate_question(
                question_id, question, relevant_contexts
            )
            self.evaluation_results.append(result)
        logger.info("Evaluation of all questions completed.")

    @logger.catch(reraise=True)
    def evaluate_question(
        self,
        question_id: str,
        question: str,
        relevant_contexts: Dict[str, List[str]],
    ) -> EvaluationResult:
        """
        Evaluate a single question.
        Returns an EvaluationResult object.
        """
        logger.trace(f"Evaluating question ID: {question_id}")
        if self.retriever is None:
            raise ValueError(f"Retriever is not initialized for question ID: {question_id}")
        retrieved_docs: List[Document] = self.retriever.invoke(question)
        retrieved_ids = [doc.metadata["id"] for doc in retrieved_docs]
        expected_ids = set(relevant_contexts[question_id])
        is_hit = bool(expected_ids.intersection(retrieved_ids))
        # Find rank of first relevant document
        rank = None
        for idx, doc_id in enumerate(retrieved_ids, start=1):
            if doc_id in expected_ids:
                rank = idx
                break
        result = EvaluationResult(
            question_id=question_id,
            question=question,
            retrieved_ids=retrieved_ids,
            expected_ids=expected_ids,
            is_hit=is_hit,
            rank=rank,
        )
        logger.trace(f"Evaluation Result for question ID {question_id}: {result}")
        return result

    def compute_hit_rate(self) -> None:
        """
        Compute hit rate from the evaluation results.
        Stores the result in the class state.
        """
        hits = sum(1 for result in self.evaluation_results if result.is_hit)
        total = len(self.evaluation_results)
        self.hit_rate = hits / total if total > 0 else 0.0
        logger.info(f"Computed hit rate: {self.hit_rate}")

    def compute_mean_reciprocal_rank(self) -> None:
        """
        Compute Mean Reciprocal Rank (MRR) from the evaluation results.
        Stores the result in the class state.
        """
        reciprocal_ranks = []
        for result in self.evaluation_results:
            if result.rank:
                reciprocal_ranks.append(1 / result.rank)
            else:
                reciprocal_ranks.append(0.0)
        total = len(reciprocal_ranks)
        self.mrr = sum(reciprocal_ranks) / total if total > 0 else 0.0
        logger.info(f"Computed Mean Reciprocal Rank (MRR): {self.mrr}")

    def generate_evaluation_report(self) -> pd.DataFrame:
        """
        Generate a detailed evaluation report as a Pandas DataFrame.
        """
        logger.info("Generating evaluation report.")
        data = []
        for result in self.evaluation_results:
            data.append({
                'question_id': result.question_id,
                'question': result.question,
                'is_hit': result.is_hit,
                'rank': result.rank,
                'retrieved_ids': result.retrieved_ids,
                'expected_ids': list(result.expected_ids),
            })
        df = pd.DataFrame(data)
        logger.info("Evaluation report generated.")
        return df

    @staticmethod
    @logger.catch(reraise=True)
    def compare_evaluations(
        evaluation_data: List[Tuple[str, List[EvaluationResult]]]
    ) -> pd.DataFrame:
        """
        Compare multiple models based on their EvaluationResult lists.
        Returns a Pandas DataFrame containing comparison metrics.
        """
        logger.info("Starting comparison of evaluation results.")
        metrics_list: List[EvaluationMetrics] = []

        for model_name, results in evaluation_data:
            # Validate the container type before computing metrics.
            if not isinstance(results, list):
                raise ValueError(f"Invalid results type for model {model_name}. Expected List[EvaluationResult]. Actual {type(results).__name__}")

            hits = sum(1 for result in results if result.is_hit)
            total = len(results)
            hit_rate = hits / total if total > 0 else 0.0

            reciprocal_ranks = []
            for result in results:
                if result.rank:
                    reciprocal_ranks.append(1 / result.rank)
                else:
                    reciprocal_ranks.append(0.0)
            mrr = sum(reciprocal_ranks) / total if total > 0 else 0.0

            metrics = EvaluationMetrics(
                model_name=model_name,
                hit_rate=hit_rate,
                mrr=mrr,
                total_questions=total,
            )
            metrics_list.append(metrics)
            logger.info(
                f"Computed metrics for {model_name} - Hit Rate: {hit_rate}, MRR: {mrr}"
            )

        # Convert metrics into a DataFrame
        comparison_df = pd.DataFrame([vars(metric) for metric in metrics_list])
        logger.info("Comparison DataFrame created.")
        return comparison_df



All that's left to do is evaluate. We'll compare our fine-tuned model against:

1. OpenAI's closed-source `text-embedding-3-small`
2. The base non-fine-tuned version of `Snowflake/snowflake-arctic-embed-m`.

Let's see how it stacks up!
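
Before reading the numbers, it may help to see how the two metrics reported by the `Evaluator` behave on a toy case. The sketch below uses hypothetical ranks and hypothetical helper names; a rank of `None` means that no relevant document appeared in the top-k results, which counts as a miss for hit rate and contributes 0 to MRR.

```python
from typing import List, Optional


def hit_rate(ranks: List[Optional[int]]) -> float:
    """Share of questions with at least one relevant document in the top-k."""
    if not ranks:
        return 0.0
    return sum(1 for rank in ranks if rank is not None) / len(ranks)


def mean_reciprocal_rank(ranks: List[Optional[int]]) -> float:
    """Average of 1/rank of the first relevant document (0 for a miss)."""
    if not ranks:
        return 0.0
    return sum(1 / rank if rank else 0.0 for rank in ranks) / len(ranks)


# Hypothetical ranks of the first relevant document for four questions.
toy_ranks = [1, 2, None, 1]

print(f"Hit rate: {hit_rate(toy_ranks)}")             # 3 of 4 questions hit
print(f"MRR: {mean_reciprocal_rank(toy_ranks)}")      # (1 + 0.5 + 0 + 1) / 4
```

Hit rate only asks whether a relevant document made it into the top-k, while MRR also rewards placing it near the top, so two models with the same hit rate can still differ in MRR.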

### `text-embedding-3-small`

In [55]:
# te3_openai = OpenAIEmbeddings(model="text-embedding-3-small")
# te3_results = evaluate_openai(test_dataset, te3_openai)
from typing import List
from langchain_openai.embeddings import OpenAIEmbeddings
te3_openai = OpenAIEmbeddings(model="text-embedding-3-small")
te3_openai_evaluator = Evaluator(test_dataset, te3_openai, top_k=5)
te3_openai_evaluator.evaluate()

2024-09-19 01:29:50.792 | INFO     | __main__:prepare_retriever:66 - Preparing vector store and retriever.
2024-09-19 01:29:52.050 | INFO     | __main__:prepare_retriever:74 - Retriever prepared.
2024-09-19 01:29:52.051 | INFO     | __main__:evaluate_all_questions:82 - Starting evaluation of all questions.
2024-09-19 01:31:08.846 | INFO     | __main__:evaluate_all_questions:92 - Evaluation of all questions completed.
2024-09-19 01:31:08.848 | INFO     | __main__:compute_hit_rate:137 - Computed hit rate: 0.9870967741935484
2024-09-19 01:31:08.849 | INFO     | __main__:compute_mean_reciprocal_rank:152 - Computed Mean Reciprocal Rank (MRR): 0.9381182795698925
2024-09-19 01:31:08.850 | INFO     | __main__:evaluate:59 - Model evaluation complete. Hit Rate: 0.9870967741935484, MRR: 0.9381182795698925


In [56]:
te3_results: List[EvaluationResult] = te3_openai_evaluator.evaluation_results

logger.info(f"OpenAI model hit rate: {te3_openai_evaluator.hit_rate}")
logger.info(f"OpenAI model MRR: {te3_openai_evaluator.mrr}")

te3_report = te3_openai_evaluator.generate_evaluation_report()
print(te3_report.head())

2024-09-19 01:31:09.042 | INFO     | __main__:<module>:3 - OpenAI model hit rate: 0.9870967741935484
2024-09-19 01:31:09.043 | INFO     | __main__:<module>:4 - OpenAI model MRR: 0.9381182795698925
2024-09-19 01:31:09.044 | INFO     | __main__:generate_evaluation_report:158 - Generating evaluation report.
2024-09-19 01:31:09.050 | INFO     | __main__:generate_evaluation_report:170 - Evaluation report generated.


                            question_id  \
0  ee57e704-11a2-494e-a23e-529d91141a5b   
1  7b478e5a-3991-46f7-9935-726a76ee7b3f   
2  9867205c-a82f-49f4-8fd5-713ab8d8bf14   
3  682c1384-2df1-4788-896a-45629b573a34   
4  9c287439-c95d-4d88-b82c-27ed2833a61d   

                                            question  is_hit  rank  \
0  How does the size of the provider impact the c...    True   1.0   
1  What are the simplified ways of compliance for...    True   1.0   
2  How can AI systems be designed to ensure equit...    True   1.0   
3  What ethical considerations should be taken in...    True   1.0   
4  What information must be included in the EU de...    True   1.0   

                                       retrieved_ids  \
0  [0647e8da-06dc-418d-a5db-7d51468ae3de, 358f7e5...   
1  [0647e8da-06dc-418d-a5db-7d51468ae3de, 358f7e5...   
2  [096bff58-c010-431a-a0b9-b8661bd551b4, 4af04c0...   
3  [096bff58-c010-431a-a0b9-b8661bd551b4, 3b0c926...   
4  [bb37130f-2f53-453d-bb06-f47baded20c7

In [57]:
# te3_results_df = pd.DataFrame(te3_results)

In [58]:
te3_hit_rate = te3_report["is_hit"].mean()
te3_hit_rate

0.9870967741935484

### `Snowflake/snowflake-arctic-embed-m` (base)

In [59]:
huggingface_embeddings = HuggingFaceEmbeddings(model_name="Snowflake/snowflake-arctic-embed-m")
arctic_evaluator = Evaluator(test_dataset, huggingface_embeddings, top_k=5)
arctic_evaluator.evaluate()

2024-09-19 01:31:10.064 | INFO     | __main__:prepare_retriever:66 - Preparing vector store and retriever.
2024-09-19 01:31:11.706 | INFO     | __main__:prepare_retriever:74 - Retriever prepared.
2024-09-19 01:31:11.707 | INFO     | __main__:evaluate_all_questions:82 - Starting evaluation of all questions.

Evaluating model:   0%|          | 0/310 [00:00<?, ?it/s]

2024-09-19 01:31:16.449 | INFO     | __main__:evaluate_all_questions:92 - Evaluation of all questions completed.
2024-09-19 01:31:16.451 | INFO     | __main__:compute_hit_rate:137 - Computed hit rate: 0.5516129032258065
2024-09-19 01:31:16.452 | INFO     | __main__:compute_mean_reciprocal_rank:152 - Computed Mean Reciprocal Rank (MRR): 0.3795698924731183
2024-09-19 01:31:16.452 | INFO     | __main__:evaluate:59 - Model evaluation complete. Hit Rate: 0.5516129032258065, MRR: 0.3795698924731183


In [60]:
# arctic_embed_m_results_df = pd.DataFrame(arctic_embed_m_results)
from typing import List

arctic_results: List[EvaluationResult] = arctic_evaluator.evaluation_results
logger.info(f"Pre-trained Arctic model hit rate: {arctic_evaluator.hit_rate}")
logger.info(f"Pre-trained Arctic model MRR: {arctic_evaluator.mrr}")

arctic_report = arctic_evaluator.generate_evaluation_report()
print(arctic_report.head())

2024-09-19 01:31:16.653 | INFO     | __main__:<module>:5 - Pre-trained Arctic model hit rate: 0.5516129032258065
2024-09-19 01:31:16.654 | INFO     | __main__:<module>:6 - Pre-trained Arctic model MRR: 0.3795698924731183
2024-09-19 01:31:16.655 | INFO     | __main__:generate_evaluation_report:158 - Generating evaluation report.
2024-09-19 01:31:16.657 | INFO     | __main__:generate_evaluation_report:170 - Evaluation report generated.


                            question_id  \
0  ee57e704-11a2-494e-a23e-529d91141a5b   
1  7b478e5a-3991-46f7-9935-726a76ee7b3f   
2  9867205c-a82f-49f4-8fd5-713ab8d8bf14   
3  682c1384-2df1-4788-896a-45629b573a34   
4  9c287439-c95d-4d88-b82c-27ed2833a61d   

                                            question  is_hit  rank  \
0  How does the size of the provider impact the c...   False   NaN   
1  What are the simplified ways of compliance for...   False   NaN   
2  How can AI systems be designed to ensure equit...   False   NaN   
3  What ethical considerations should be taken in...   False   NaN   
4  What information must be included in the EU de...   False   NaN   

                                       retrieved_ids  \
0  [e1e04ed1-5736-4de0-ae9d-8ab290a50f74, 694b37a...   
1  [e1e04ed1-5736-4de0-ae9d-8ab290a50f74, 5ee2854...   
2  [e1e04ed1-5736-4de0-ae9d-8ab290a50f74, ddc2e1f...   
3  [e1e04ed1-5736-4de0-ae9d-8ab290a50f74, 5ee2854...   
4  [e1e04ed1-5736-4de0-ae9d-8ab290a50f74

In [61]:
# arctic_embed_m_hit_rate = arctic_embed_m_results_df["is_hit"].mean()
# arctic_embed_m_hit_rate

arctic_embed_m_hit_rate=arctic_report["is_hit"].mean()
arctic_embed_m_hit_rate

0.5516129032258065

### `Snowflake/snowflake-arctic-embed-m` (fine-tuned)

In [62]:
finetune_embeddings = HuggingFaceEmbeddings(model_name=finetuned_model_output_path)
finetune_evaluator = Evaluator(test_dataset, finetune_embeddings, top_k=5)
finetune_evaluator.evaluate()

Some weights of BertModel were not initialized from the model checkpoint at finetuned_arctic and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2024-09-19 01:31:17.331 | INFO     | __main__:prepare_retriever:66 - Preparing vector store and retriever.
2024-09-19 01:31:18.976 | INFO     | __main__:prepare_retriever:74 - Retriever prepared.
2024-09-19 01:31:18.978 | INFO     | __main__:evaluate_all_questions:82 - Starting evaluation of all questions.

Evaluating model:   0%|          | 0/310 [00:00<?, ?it/s]

2024-09-19 01:31:23.983 | INFO     | __main__:evaluate_all_questions:92 - Evaluation of all questions completed.
2024-09-19 01:31:23.985 | INFO     | __main__:compute_hit_rate:137 - Computed hit rate: 0.9967741935483871
2024-09-19 01:31:23.986 | INFO     | __main__:compute_mean_reciprocal_rank:152 - Computed Mean Reciprocal Rank (MRR): 0.9748387096774194
2024-09-19 01:31:23.986 | INFO     | __main__:evaluate:59 - Model evaluation complete. Hit Rate: 0.9967741935483871, MRR: 0.9748387096774194


In [63]:
# finetune_results_df = pd.DataFrame(finetune_results)
finetune_results = finetune_evaluator.evaluation_results
logger.info(f"Fine-tuned Arctic model hit rate: {finetune_evaluator.hit_rate}")
logger.info(f"Fine-tuned Arctic model MRR: {finetune_evaluator.mrr}")

finetune_report = finetune_evaluator.generate_evaluation_report()
logger.info("Generated evaluation report for fine-tuned model")
print(finetune_report.head())

2024-09-19 01:31:24.160 | INFO     | __main__:<module>:3 - Fine-tuned Arctic model hit rate: 0.9967741935483871
2024-09-19 01:31:24.162 | INFO     | __main__:<module>:4 - Fine-tuned Arctic model MRR: 0.9748387096774194
2024-09-19 01:31:24.163 | INFO     | __main__:generate_evaluation_report:158 - Generating evaluation report.
2024-09-19 01:31:24.166 | INFO     | __main__:generate_evaluation_report:170 - Evaluation report generated.
2024-09-19 01:31:24.167 | INFO     | __main__:<module>:7 - Generated evaluation report for fine-tuned model


                            question_id  \
0  ee57e704-11a2-494e-a23e-529d91141a5b   
1  7b478e5a-3991-46f7-9935-726a76ee7b3f   
2  9867205c-a82f-49f4-8fd5-713ab8d8bf14   
3  682c1384-2df1-4788-896a-45629b573a34   
4  9c287439-c95d-4d88-b82c-27ed2833a61d   

                                            question  is_hit  rank  \
0  How does the size of the provider impact the c...    True   1.0   
1  What are the simplified ways of compliance for...    True   1.0   
2  How can AI systems be designed to ensure equit...    True   1.0   
3  What ethical considerations should be taken in...    True   1.0   
4  What information must be included in the EU de...    True   1.0   

                                       retrieved_ids  \
0  [0647e8da-06dc-418d-a5db-7d51468ae3de, 358f7e5...   
1  [0647e8da-06dc-418d-a5db-7d51468ae3de, a401596...   
2  [096bff58-c010-431a-a0b9-b8661bd551b4, 4af04c0...   
3  [096bff58-c010-431a-a0b9-b8661bd551b4, 273b51c...   
4  [bb37130f-2f53-453d-bb06-f47baded20c7

In [64]:
# finetune_hit_rate = finetune_results_df["is_hit"].mean()
# finetune_hit_rate
finetune_hit_rate = finetune_report["is_hit"].mean()
finetune_hit_rate

0.9967741935483871

In [65]:
evaluation_data = [
    ("OpenAI Embeddings", te3_results),
    ("Pre-trained Arctic Model", arctic_results),
    ("Fine-tuned Arctic Model", finetune_results),
]

# Generate comparison report
comparison_df = Evaluator.compare_evaluations(evaluation_data)
print("Comparison of Evaluation Metrics:")
print(comparison_df)

2024-09-19 01:31:24.403 | INFO     | __main__:compare_evaluations:181 - Starting comparison of evaluation results.
2024-09-19 01:31:24.404 | INFO     | __main__:compare_evaluations:210 - Computed metrics for OpenAI Embeddings - Hit Rate: 0.9870967741935484, MRR: 0.9381182795698925
2024-09-19 01:31:24.405 | INFO     | __main__:compare_evaluations:210 - Computed metrics for Pre-trained Arctic Model - Hit Rate: 0.5516129032258065, MRR: 0.3795698924731183
2024-09-19 01:31:24.406 | INFO     | __main__:compare_evaluations:210 - Computed metrics for Fine-tuned Arctic Model - Hit Rate: 0.9967741935483871, MRR: 0.9748387096774194
2024-09-19 01:31:24.408 | INFO     | __main__:compare_evaluations:216 - Comparison DataFrame created.


Comparison of Evaluation Metrics:
                 model_name  hit_rate       mrr  total_questions
0         OpenAI Embeddings  0.987097  0.938118              310
1  Pre-trained Arctic Model  0.551613  0.379570              310
2   Fine-tuned Arctic Model  0.996774  0.974839              310
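
Hit rates this close to 1.0 are easier to compare as raw miss counts. The sketch below converts the hit rates from the table above back into the number of missed questions out of 310; `missed_questions` is a helper introduced here for illustration.

```python
# Hit rates copied from the comparison table above.
hit_rates = {
    "OpenAI Embeddings": 0.9870967741935484,
    "Pre-trained Arctic Model": 0.5516129032258065,
    "Fine-tuned Arctic Model": 0.9967741935483871,
}
total_questions = 310


def missed_questions(hit_rate: float, total: int) -> int:
    """Number of questions whose expected document was not in the top-k."""
    return total - round(hit_rate * total)


for model_name, rate in hit_rates.items():
    print(f"{model_name}: {missed_questions(rate, total_questions)} missed")
# OpenAI Embeddings: 4 missed
# Pre-trained Arctic Model: 139 missed
# Fine-tuned Arctic Model: 1 missed
```

On this test set, fine-tuning took the Arctic model from 139 misses down to 1, compared with 4 for `text-embedding-3-small`.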


# 🤝 Breakout Room #2

## Task 1: Vibe Checking the RAG Pipeline

Now that we've fine-tuned our embedding model, let's vibe check our RAG pipeline on a few common questions!

### Creating New Chunks

In order to try and evaluate our system more fairly, let's create new chunks that we will use to create our Vector Store.

In [None]:
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = 600,
    chunk_overlap  = 50,
    length_function = len
)

training_documents = text_splitter.split_documents(training_documents_loaded.load())
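
To build some intuition for `chunk_size` and `chunk_overlap`, here is a deliberately simplified sliding-window sketch. It is not how `RecursiveCharacterTextSplitter` works internally (that splitter also prefers to break on separators such as paragraph and line breaks, so real chunk boundaries will differ); `sliding_chunks` is a hypothetical helper that only illustrates how consecutive chunks can share text.

```python
from typing import List


def sliding_chunks(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Cut text into fixed-size windows where each window repeats the tail of the previous one."""
    step = chunk_size - chunk_overlap
    # Stop early so the final window is never made up purely of overlap.
    last_start = max(len(text) - chunk_overlap, 1)
    return [text[start:start + chunk_size] for start in range(0, last_start, step)]


# A 1,500-character dummy "document".
sample_text = "".join(str(i % 10) for i in range(1500))
chunks = sliding_chunks(sample_text, chunk_size=600, chunk_overlap=50)

print([len(chunk) for chunk in chunks])    # [600, 600, 400]
print(chunks[0][-50:] == chunks[1][:50])   # True
```

The overlap means a sentence that straddles a chunk boundary still appears intact in at least one chunk.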

### Base Chain

We'll start by constructing our base chain, which will use the base (non-fine-tuned) embedding model for retrieval.

#### R - Retrieval

In [None]:
from langchain_community.vectorstores import FAISS

base_vectorstore = FAISS.from_documents(training_documents, huggingface_embeddings)
base_retriever = base_vectorstore.as_retriever(search_kwargs={"k": 6})
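
Conceptually, the retriever embeds the question and returns the `k` stored chunks whose vectors are closest to it. Here is a minimal brute-force sketch of that idea using Euclidean distance on toy 2-D vectors; `top_k` and the vectors are made up for illustration, and a real vector store uses an optimized index rather than this loop.

```python
import math
from typing import List, Sequence, Tuple


def top_k(query: Sequence[float], vectors: List[Sequence[float]], k: int) -> List[Tuple[int, float]]:
    """Return (index, distance) pairs for the k stored vectors closest to the query."""
    distances = [(index, math.dist(query, vector)) for index, vector in enumerate(vectors)]
    return sorted(distances, key=lambda pair: pair[1])[:k]


# Toy 2-D "embeddings"; real embedding vectors have hundreds of dimensions.
chunk_vectors = [(0.0, 1.0), (1.0, 0.0), (0.9, 0.1), (-1.0, 0.0)]
query_vector = (1.0, 0.2)

print([index for index, _ in top_k(query_vector, chunk_vectors, k=2)])  # [2, 1]
```

Fine-tuning changes where questions and chunks land in this space, which is why the same search can return very different chunks for the same question.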

#### A - Augmented

In [None]:
from langchain_core.prompts import ChatPromptTemplate

RAG_PROMPT = """\
Given a provided context and a question, you must answer the question. If you do not know the answer, you must state that you do not know.

Context:
{context}

Question:
{question}

Answer:
"""

rag_prompt_template = ChatPromptTemplate.from_template(RAG_PROMPT)

#### G - Generation

In [None]:
from langchain_openai import ChatOpenAI

rag_llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0
)

#### RAG - LCEL RAG Pipeline

In [None]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel

base_rag_chain = (
    {"context": itemgetter("question") | base_retriever, "question": itemgetter("question")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt_template | rag_llm | StrOutputParser(), "context": itemgetter("context")}
)
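
If the LCEL syntax is hard to read, it can help to see the same data flow written as plain Python. This is only a sketch with stand-in functions: `run_rag`, `fake_retrieve`, and `fake_generate` are hypothetical names, not part of LangChain. The middle `RunnablePassthrough.assign(...)` step re-assigns `context` to itself, so it has no separate stage here.

```python
from typing import Callable, Dict, List


def run_rag(
    question: str,
    retrieve: Callable[[str], List[str]],
    generate: Callable[[str], str],
) -> Dict[str, object]:
    """Mirror the stages of the chain using plain functions."""
    # Stage 1: fetch context for the question and carry the question forward.
    state = {"context": retrieve(question), "question": question}
    # Stage 2: fill the prompt with the context and the question.
    prompt = f"Context:\n{state['context']}\n\nQuestion:\n{state['question']}\n\nAnswer:\n"
    # Stage 3: return the generated response together with the context that was used.
    return {"response": generate(prompt), "context": state["context"]}


# Stand-ins so the sketch runs without a vector store or an API key.
def fake_retrieve(question: str) -> List[str]:
    return ["The bus was driven by Kyle, the Bus Driver."]


def fake_generate(prompt: str) -> str:
    return "Kyle"


output = run_rag("Who drives the bus?", fake_retrieve, fake_generate)
print(output["response"])  # Kyle
print(output["context"])   # ['The bus was driven by Kyle, the Bus Driver.']
```

Returning the context alongside the response is what lets us inspect, and later score, what the retriever actually found.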

In [None]:
base_rag_chain.invoke({"question" : "Why does the EU want to regulate AI?"})["response"]

In [None]:
base_rag_chain.invoke({"question" : "What are the codes of practice?"})["response"]

In [None]:
base_rag_chain.invoke({"question" : "How many parameters is too many parameters?"})["response"]

In [None]:
base_rag_chain.invoke({"question" : "What is an emotion recognition system and why is it important?"})["response"]

### Fine-tuned Embedding Model

Now let's rebuild our RAG chain with the Fine-tuned model - the only component we need to change is our `FAISS` vectorstore!

In [None]:
finetune_vectorstore = FAISS.from_documents(training_documents, finetune_embeddings)
finetune_retriever = finetune_vectorstore.as_retriever(search_kwargs={"k": 6})

In [None]:
finetune_rag_chain = (
    {"context": itemgetter("question") | finetune_retriever, "question": itemgetter("question")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt_template | rag_llm | StrOutputParser(), "context": itemgetter("context")}
)

In [None]:
finetune_rag_chain.invoke({"question" : "Why does the EU want to regulate AI?"})["response"]

In [None]:
finetune_rag_chain.invoke({"question" : "What are the codes of practice?"})["response"]

In [None]:
finetune_rag_chain.invoke({"question" : "How many parameters is too many parameters?"})["response"]

In [None]:
finetune_rag_chain.invoke({"question" : "What is an emotion recognition system and why is it important?"})["response"]

#####❓Question #2:

Which LCEL RAG Chain do you think answered the questions better, and why?

## Task 2: RAGAS Evaluation

It's great to have some idea of how our system is doing based on vibe checks, but let's use RAGAS to get more insight into how things are improving!

In [None]:
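# The `ragas.testset.generator` / `ragas.testset.evolutions` imports used below belong to the 0.1.x API.
!pip install -qU "ragas<0.2"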

### RAGAS Synthetic Testset Generation

First things first, we need to generate some data to test our model on.

Let's use our test data that we created before as a base!

In [None]:
from ragas.testset.generator import TestsetGenerator
from ragas.testset.evolutions import simple, reasoning, multi_context
from langchain_openai import OpenAIEmbeddings

generator_llm = ChatOpenAI(model="gpt-3.5-turbo")
critic_llm = ChatOpenAI(model="gpt-4o-mini")
embeddings = OpenAIEmbeddings()

In [None]:
generator = TestsetGenerator.from_langchain(
    generator_llm,
    critic_llm,
    embeddings
)

In [None]:
testset = generator.generate_with_langchain_docs(test_split_documents, test_size=20, distributions={simple: 0.5, reasoning: 0.25, multi_context: 0.25})
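
As a quick sanity check on the `distributions` argument, the weights should sum to 1, and multiplying each by `test_size` gives the rough number of questions to expect per type. This sketch uses string keys for illustration and assumes questions are allocated proportionally; the actual counts RAGAS produces may differ slightly.

```python
# Weights from the call above, keyed by name for illustration.
distributions = {"simple": 0.5, "reasoning": 0.25, "multi_context": 0.25}
test_size = 20

# The weights describe a probability distribution, so they should sum to 1.
assert abs(sum(distributions.values()) - 1.0) < 1e-9

expected_counts = {name: round(test_size * share) for name, share in distributions.items()}
print(expected_counts)  # {'simple': 10, 'reasoning': 5, 'multi_context': 5}
```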

In [None]:
testset.to_pandas().head()

### Generating Answer Datasets

For each of our pipelines, let's generate answers to these questions!

Once we have our questions, answers, contexts, and ground truths, we can move on to evaluating our datasets!

In [None]:
from datasets import Dataset
from tqdm.auto import tqdm

def generate_answers(chain, testset):
    """Run every test question through `chain` and collect the results for RAGAS."""
    testset_df = testset.to_pandas()  # convert once instead of once per column
    questions = testset_df["question"].values.tolist()
    ground_truths = testset_df["ground_truth"].values.tolist()

    answers = []
    contexts = []
    for question in tqdm(questions):
        answer = chain.invoke({"question": question})
        answers.append(answer["response"])
        # Keep only the text of each retrieved document
        contexts.append([context.page_content for context in answer["context"]])

    return Dataset.from_dict({
        "question": questions,
        "answer": answers,
        "contexts": contexts,
        "ground_truth": ground_truths,
    })

In [None]:
base_dataset = generate_answers(base_rag_chain, testset)

In [None]:
finetune_dataset = generate_answers(finetune_rag_chain, testset)

### Evaluating Using the Test Set

Now that we have a test set - it's time to evaluate our pipelines with it!

In [None]:
from ragas.metrics import (
    context_recall,
    context_precision,
)
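
Both metrics focus on the retrieval step, which is exactly what we changed. As a rough mental model (not the RAGAS implementation, which uses an LLM to make the relevance judgments), context precision rewards ranking relevant chunks near the top, and context recall measures how much of the ground truth the retrieved context supports. The functions below are hypothetical sketches that assume those judgments are already available as booleans.

```python
from typing import List


def sketch_context_precision(relevance_by_rank: List[bool]) -> float:
    """Average of precision@k taken at each rank k that holds a relevant chunk."""
    relevant_so_far = 0
    weighted_sum = 0.0
    for position, is_relevant in enumerate(relevance_by_rank, start=1):
        if is_relevant:
            relevant_so_far += 1
            weighted_sum += relevant_so_far / position  # precision@position
    return weighted_sum / relevant_so_far if relevant_so_far else 0.0


def sketch_context_recall(claims_supported: List[bool]) -> float:
    """Share of ground-truth claims that the retrieved context supports."""
    return sum(claims_supported) / len(claims_supported) if claims_supported else 0.0


# The same single relevant chunk, ranked first versus ranked second.
print(sketch_context_precision([True, False, False]))   # 1.0
print(sketch_context_precision([False, True, False]))   # 0.5

# Three of four ground-truth claims can be found in the retrieved context.
print(sketch_context_recall([True, True, False, True]))  # 0.75
```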

In [None]:
from ragas import evaluate

result = evaluate(
    base_dataset,
    metrics=[
        context_precision,
        context_recall,
    ],
)

In [None]:
result

In [None]:
result.to_pandas().head()

In [None]:
result = evaluate(
    finetune_dataset,
    metrics=[
        context_precision,
        context_recall,
    ],
)

In [None]:
result

In [None]:
result.to_pandas().head()

#### 🏗️ Activity #3:

Discuss changes that you'd make to this pipeline based on the performance improvements that you see with RAGAS and the fine-tuning.

Come up with 3 changes, and then we'll discuss these options as a group!

1. ...
2. ...
3. ...