<a href="https://colab.research.google.com/github/liyanonline/Artificial-Intelligence-with-Python/blob/master/examples/fine-tune-modernbert-rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tune ModernBERT with Synthetic Data for RAG

This notebook shows how to fine-tune `modernbert-embed-base` on synthetic data for use in a Retrieval-Augmented Generation (RAG) pipeline.

It walks through the full fine-tuning process, starting from synthetic data that has already been generated with the Synthetic Data Generator. For the methodology and further details, see the blog post: [Fine-tune ModernBERT for RAG with Synthetic Data](https://huggingface.co/blog/fine-tune-modernbert-for-rag-with-synthetic-data).

## Getting Started

### Install the Dependencies

In [None]:
!pip install torch
!pip install datasets
!pip install sentence-transformers
!pip install haystack-ai
!pip install git+https://github.com/huggingface/transformers.git  # for the latest version of transformers

### Import the Required Libraries

In [None]:
import torch
from torch.utils.data import DataLoader

from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict


from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    CrossEncoder,
    InputExample,
    SentenceTransformerTrainer,
)
from sentence_transformers.losses import TripletLoss
from sentence_transformers.training_args import (
    SentenceTransformerTrainingArguments,
    BatchSamplers,
)
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator


from haystack import Document, Pipeline
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)
from haystack.components.rankers import SentenceTransformersDiversityRanker
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.builders import ChatPromptBuilder
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
from haystack.utils.hf import HFGenerationAPIType

### Configure the Environment

In [None]:
MODEL = "nomic-ai/modernbert-embed-base"
REPO_NAME = "sdiazlor" # your HF username here
MODEL_NAME_BIENCODER = "modernbert-embed-base-biencoder-human-rights"
MODEL_NAME_CROSSENCODER = "modernbert-embed-base-crossencoder-human-rights"

In [None]:
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")

Using device: mps


## Pre-process the Synthetic Data

In [None]:
# Combine the generated datasets from files and prompts

dataset_rag_from_file = load_dataset(f"{REPO_NAME}/rag-human-rights-from-files", split="train")
dataset_rag_from_prompt = load_dataset(f"{REPO_NAME}/rag-human-rights-from-prompt", split="train")

combined_rag_dataset = concatenate_datasets(
    [dataset_rag_from_file, dataset_rag_from_prompt]
)

combined_rag_dataset

Dataset({
    features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],
    num_rows: 1000
})

In [None]:
# Filter out examples with empty or NaN values

def filter_empty_or_nan(example):
    return all(
        value is not None and str(value).strip() != "" for value in example.values()
    )

filtered_rag_dataset = combined_rag_dataset.filter(filter_empty_or_nan).shuffle(seed=42)
filtered_rag_dataset

Dataset({
    features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],
    num_rows: 828
})
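
The filter above drops 172 of the 1000 rows. One thing worth checking: the predicate tests for `None` and blank strings, but `str(float("nan"))` is the non-empty string `"nan"`, so a float NaN would pass it. The following is a minimal sketch of a NaN-aware variant on plain dictionaries; the helper names `is_missing` and `keep_row` are hypothetical and are not part of the notebook.

```python
import math

def is_missing(value):
    """Return True for None, float NaN, or blank/whitespace-only strings."""
    if value is None:
        return True
    if isinstance(value, float) and math.isnan(value):
        return True
    return str(value).strip() == ""

def keep_row(example):
    """Keep a row only if every field holds a usable value."""
    return not any(is_missing(v) for v in example.values())

# Toy rows (made up for illustration)
rows = [
    {"context": "Article 6 ...", "question": "What is a fair trial?"},
    {"context": "Article 6 ...", "question": "   "},
    {"context": None, "question": "What is a fair trial?"},
    {"context": float("nan"), "question": "What is a fair trial?"},
]
kept = [keep_row(r) for r in rows]
print(kept)  # [True, False, False, False]
```

A function with this shape could be passed to `Dataset.filter` in place of `filter_empty_or_nan` if the data turns out to contain float NaNs.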

In [None]:
# Rename, select and reorder columns according to the expected format for the SentenceTransformer and CrossEncoder models

def rename_and_reorder_columns(dataset, rename_map, selected_columns):
    for old_name, new_name in rename_map.items():
        if old_name in dataset.column_names:
            dataset = dataset.rename_column(old_name, new_name)
    dataset = dataset.select_columns(selected_columns)
    return dataset

clean_rag_dataset_biencoder = rename_and_reorder_columns(
    filtered_rag_dataset,
    rename_map={"context": "anchor", "positive_retrieval": "positive", "negative_retrieval": "negative"},
    selected_columns=["anchor", "positive", "negative"],
)

clean_rag_dataset_crossencoder = rename_and_reorder_columns(
    filtered_rag_dataset,
    rename_map={"context": "anchor", "positive_retrieval": "positive"}, #TODO
    selected_columns=["anchor", "positive"],
)

print(clean_rag_dataset_biencoder)
print(clean_rag_dataset_crossencoder)

Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 828
})
Dataset({
    features: ['anchor', 'positive'],
    num_rows: 828
})


In [None]:
# Add scores to train the CrossEncoder model, which requires sentence pairs with a score indicating how related they are.
# Check the available models: https://huggingface.co/spaces/mteb/leaderboard

model_reranking = CrossEncoder(
    model_name="Snowflake/snowflake-arctic-embed-m-v1.5", device=device
)

def add_reranking_scores(batch):
    pairs = list(zip(batch["anchor"], batch["positive"]))
    batch["score"] = model_reranking.predict(pairs)
    return batch

clean_rag_dataset_crossencoder = clean_rag_dataset_crossencoder.map(
    add_reranking_scores, batched=True, batch_size=250
)
clean_rag_dataset_crossencoder

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Snowflake/snowflake-arctic-embed-m-v1.5 and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/828 [00:00<?, ? examples/s]

Dataset({
    features: ['anchor', 'positive', 'score'],
    num_rows: 828
})
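
With `batched=True`, the mapped function receives a dictionary of column lists rather than a single row, which is why `add_reranking_scores` zips two columns into pairs. The sketch below mimics that shape in plain Python. `toy_score` is a hypothetical stand-in (token overlap) used only so the example runs without a model; it is not how the CrossEncoder scores pairs.

```python
def toy_score(anchor, positive):
    """Hypothetical stand-in scorer: Jaccard overlap of lowercase tokens."""
    a, p = set(anchor.lower().split()), set(positive.lower().split())
    return len(a & p) / len(a | p) if a | p else 0.0

def add_scores(batch):
    """Mimic a batched map function: column lists in, column lists plus 'score' out."""
    pairs = list(zip(batch["anchor"], batch["positive"]))
    batch["score"] = [toy_score(a, p) for a, p in pairs]
    return batch

# Toy batch (made up for illustration)
batch = {
    "anchor": ["everyone has the right to life", "no one shall be held in slavery"],
    "positive": ["the right to life is protected", "weather report for tuesday"],
}
scored = add_scores(batch)
print(scored["score"])  # [0.5, 0.0]
```

Note that the warning printed above reports that the classification head of the scoring model was newly initialized, so it is worth inspecting a few real scores before training on them.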

In [None]:
# Split the datasets into training and evaluation sets
def split_dataset(dataset, train_size=0.8, seed=42):
    train_eval_split = dataset.train_test_split(test_size=1 - train_size, seed=seed)

    dataset_dict = DatasetDict(
        {"train": train_eval_split["train"], "eval": train_eval_split["test"]}
    )

    return dataset_dict

dataset_rag_biencoder = split_dataset(clean_rag_dataset_biencoder)
dataset_rag_crossencoder = split_dataset(clean_rag_dataset_crossencoder)

print(dataset_rag_biencoder)
print(dataset_rag_crossencoder)

DatasetDict({
    train: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 662
    })
    eval: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 166
    })
})
DatasetDict({
    train: Dataset({
        features: ['anchor', 'positive', 'score'],
        num_rows: 662
    })
    eval: Dataset({
        features: ['anchor', 'positive', 'score'],
        num_rows: 166
    })
})
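
The 662/166 split can be reproduced by hand. This is a small arithmetic sketch, assuming the test size is rounded up and the remainder goes to the training split, which is consistent with the row counts printed above.

```python
import math

n_rows = 828
test_fraction = 1 - 0.8  # same expression as in split_dataset

# Assumption: the eval size is rounded up, the rest is used for training
n_eval = math.ceil(test_fraction * n_rows)
n_train = n_rows - n_eval
print(n_train, n_eval)  # 662 166
```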


## Train the Bi-Encoder model for Retrieval

In [None]:
# Load the base model and create the SentenceTransformer model
model_biencoder = SentenceTransformer(
    MODEL,
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name=MODEL_NAME_BIENCODER,
    ),
)
model_biencoder.gradient_checkpointing_enable()  # Enable gradient checkpointing to save memory

In [None]:
# Select the TripletLoss loss function, which requires sentence triplets (anchor, positive, negative)
# Check the available losses: https://sbert.net/docs/sentence_transformer/loss_overview.html

loss_biencoder = TripletLoss
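
To make the objective concrete, here is a minimal sketch of the triplet loss for a single example, `max(d(a, p) - d(a, n) + margin, 0)`. It assumes Euclidean distance and a margin of 5, which to my knowledge are the defaults of `TripletLoss` in sentence-transformers; the 2-D vectors are toy values, not real embeddings.

```python
import math

def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_loss(anchor, positive, negative, margin=5.0):
    """max(d(a, p) - d(a, n) + margin, 0) for a single triplet."""
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)

anchor = [0.0, 0.0]
positive = [3.0, 4.0]       # distance 5 from the anchor
far_negative = [6.0, 8.0]   # distance 10 from the anchor
near_negative = [0.0, 1.0]  # distance 1 from the anchor

loss_easy = triplet_loss(anchor, positive, far_negative)   # 5 - 10 + 5 = 0
loss_hard = triplet_loss(anchor, positive, near_negative)  # 5 - 1 + 5 = 9
print(loss_easy, loss_hard)  # 0.0 9.0
```

The loss is zero once the negative is at least `margin` farther from the anchor than the positive, so only triplets that violate the margin contribute gradient.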

In [None]:
# Define the training arguments for the SentenceTransformer model
# Customize them as needed for your requirements

training_args = SentenceTransformerTrainingArguments(
    output_dir=f"models/{MODEL_NAME_BIENCODER}",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=4,
    warmup_ratio=0.1,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    fp16=False,  # or True if stable on your MPS device
    bf16=False,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    logging_steps=100,
    load_best_model_at_end=True,
    use_mps_device=(device == "mps"),
)
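
A quick way to sanity-check these arguments is to work out the effective batch size and the approximate number of optimizer updates. The sketch below is rough arithmetic under stated assumptions: a single device, 662 training rows, and ceiling rounding. The exact rounding depends on the `transformers` version, and the `NO_DUPLICATES` batch sampler may change the number of batches.

```python
import math

train_rows = 662
per_device_batch = 4
grad_accum = 4
epochs = 3
warmup_ratio = 0.1
num_devices = 1  # assumption: single device

# Examples seen per optimizer update
effective_batch = per_device_batch * grad_accum * num_devices

# Approximate step counts (ceiling rounding is an assumption)
batches_per_epoch = math.ceil(train_rows / (per_device_batch * num_devices))
updates_per_epoch = math.ceil(batches_per_epoch / grad_accum)
total_updates = updates_per_epoch * epochs
warmup_updates = math.ceil(total_updates * warmup_ratio)

print(effective_batch, batches_per_epoch, updates_per_epoch, total_updates, warmup_updates)
# 16 166 42 126 13
```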



In [None]:
# Define the evaluator to assess the performance of the model
triplet_evaluator = TripletEvaluator(
    anchors=dataset_rag_biencoder["eval"]["anchor"],
    positives=dataset_rag_biencoder["eval"]["positive"],
    negatives=dataset_rag_biencoder["eval"]["negative"],
)
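
The cosine accuracy reported during training can be read as the fraction of triplets in which the anchor is more similar to the positive than to the negative. The following is a minimal pure-Python sketch of that metric on toy 2-D vectors. It illustrates the idea only and is not the library's implementation.

```python
import math

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def cosine_accuracy(triplets):
    """Fraction of (anchor, positive, negative) triplets ranked correctly."""
    correct = sum(cosine(a, p) > cosine(a, n) for a, p, n in triplets)
    return correct / len(triplets)

# Toy (anchor, positive, negative) embeddings, made up for illustration
triplets = [
    ([1.0, 0.0], [0.9, 0.1], [0.0, 1.0]),   # ranked correctly
    ([1.0, 0.0], [0.8, 0.2], [-1.0, 0.0]),  # ranked correctly
    ([0.0, 1.0], [1.0, 0.0], [0.1, 0.9]),   # ranked incorrectly
    ([1.0, 1.0], [1.0, 0.9], [1.0, -1.0]),  # ranked correctly
]
accuracy = cosine_accuracy(triplets)
print(accuracy)  # 0.75
```

With 166 evaluation triplets, an accuracy of 0.96988 is consistent with 161 correctly ranked triplets, and 0.981928 with 163.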

In [None]:
# Train the model. This will take some time depending on the size of the dataset and the model
# Remember to adjust the training arguments according to your requirements

trainer = SentenceTransformerTrainer(
    model=model_biencoder,
    args=training_args,
    train_dataset=dataset_rag_biencoder["train"],
    eval_dataset=dataset_rag_biencoder["eval"],
    loss=loss_biencoder,
    evaluator=triplet_evaluator,
)
trainer.train()

  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


| Epoch | Training Loss | Validation Loss | Cosine Accuracy |
|-------|---------------|-----------------|-----------------|
| 1 | No log | 3.655929 | 0.96988 |
| 2 | 14.374000 | 3.498395 | 0.981928 |


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]



In [None]:
# Save the model to the local directory and push it to the Hub
model_biencoder.save_pretrained(f"models/{MODEL_NAME_BIENCODER}")
model_biencoder.push_to_hub(f"{REPO_NAME}/{MODEL_NAME_BIENCODER}")

## Train the Cross-Encoder model for Ranking

In [None]:
# Prepare the training and evaluation samples for the CrossEncoder model

train_samples = []
for row in dataset_rag_crossencoder["train"]:
    # Suppose 'score' is a float or an integer that you want to predict
    train_samples.append(
        InputExample(texts=[row["anchor"], row["positive"]], label=float(row["score"]))
    )

eval_samples = []
for row in dataset_rag_crossencoder["eval"]:
    eval_samples.append(
        InputExample(texts=[row["anchor"], row["positive"]], label=float(row["score"]))
    )

# Initialize the DataLoader for the training samples
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=4)

In [None]:
# Initialize the CrossEncoder model. Set the number of labels to 1 for regression tasks
model_crossencoder = CrossEncoder(model_name=MODEL, num_labels=1)

In [None]:
# Define the evaluator
evaluator = CECorrelationEvaluator.from_input_examples(eval_samples)

In [None]:
# Train the CrossEncoder model

model_crossencoder.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=3,
    warmup_steps=500,
    output_path=f"models/{MODEL_NAME_CROSSENCODER}",
    save_best_model=True,
)

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/166 [00:00<?, ?it/s]

Iteration:   0%|          | 0/166 [00:00<?, ?it/s]

Iteration:   0%|          | 0/166 [00:00<?, ?it/s]
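
The progress bars show 166 iterations per epoch, so three epochs give 498 steps in total, which is fewer than `warmup_steps=500`. Assuming `fit` uses a linear warmup followed by linear decay, the learning rate would still be ramping up when training ends. The sketch below reproduces that arithmetic; `warmup_linear_factor` is a hypothetical helper written for illustration.

```python
train_rows = 662
batch_size = 4
epochs = 3
warmup_steps = 500

steps_per_epoch = -(-train_rows // batch_size)  # ceiling division: 166
total_steps = steps_per_epoch * epochs          # 498

def warmup_linear_factor(step, warmup, total):
    """Learning-rate multiplier: linear warmup, then linear decay to zero."""
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, (total - step) / max(1, total - warmup))

# Largest multiplier reached at any step of the run
peak_factor = max(
    warmup_linear_factor(s, warmup_steps, total_steps) for s in range(total_steps)
)
print(total_steps, peak_factor)  # 498 0.994
```

If a warmup-then-decay shape is wanted, a smaller value such as roughly 10% of `total_steps` is a common choice; treat that as a suggestion rather than the notebook's setting.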

In [None]:
# Save the model to the local directory and push it to the Hub
model_crossencoder.save_pretrained(f"models/{MODEL_NAME_CROSSENCODER}")
model_crossencoder.push_to_hub(f"{REPO_NAME}/{MODEL_NAME_CROSSENCODER}")

## Build the RAG Pipeline

The following section is based on the Haystack tutorial [Creating Your First QA Pipeline with Retrieval-Augmentation](https://haystack.deepset.ai/tutorials/27_first_rag_pipeline); see it for further details.

In [None]:
# Add the documents to the DocumentStore
# Use the already chunked documents from the original datasets

df = combined_rag_dataset.to_pandas()
df = df.drop_duplicates(subset=["context"]) # drop duplicates based on "context" column
df = df.sample(n=10, random_state=42) # optional: sample a subset of the dataset
dataset = Dataset.from_pandas(df)

docs = [Document(content=doc["context"]) for doc in dataset]

In [None]:
# Initialize the document store and store the documents with the embeddings using our bi-encoder model

document_store = InMemoryDocumentStore()
doc_embedder = SentenceTransformersDocumentEmbedder(
    model=f"{REPO_NAME}/{MODEL_NAME_BIENCODER}",
)
doc_embedder.warm_up()

docs_with_embeddings = doc_embedder.run(docs)
document_store.write_documents(docs_with_embeddings["documents"])

text_embedder = SentenceTransformersTextEmbedder(
    model=f"{REPO_NAME}/{MODEL_NAME_BIENCODER}",
)

In [None]:
# Initialize the retriever (our bi-encoder model) and the ranker (our cross-encoder model)

retriever = InMemoryEmbeddingRetriever(document_store)
ranker = SentenceTransformersDiversityRanker(
    model=f"{REPO_NAME}/{MODEL_NAME_CROSSENCODER}"
)
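
Conceptually, the embedding retriever scores every stored document embedding against the query embedding and returns the best matches. Below is a toy sketch of that step, assuming dot-product similarity (which I believe is the default for the in-memory store). The 3-D vectors, texts, and the `top_k` helper are hypothetical and only illustrate the ranking logic.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def top_k(query_embedding, documents, k=2):
    """Rank (text, embedding) pairs by dot product with the query, highest first."""
    ranked = sorted(documents, key=lambda d: dot(query_embedding, d[1]), reverse=True)
    return [text for text, _ in ranked[:k]]

# Toy document embeddings (made up for illustration)
documents = [
    ("Right to a fair trial", [0.9, 0.1, 0.0]),
    ("Freedom of expression", [0.1, 0.9, 0.0]),
    ("Prohibition of torture", [0.2, 0.1, 0.9]),
]
query = [1.0, 0.2, 0.0]
results = top_k(query, documents)
print(results)  # ['Right to a fair trial', 'Freedom of expression']
```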

In [None]:
# Define the prompt builder and the chat generator to interact with the models using the HF Serverless Inference API

template = [
    ChatMessage.from_user(
        """
Given the following information, answer the question.

Context:
{% for document in documents %}
    {{ document.content }}
{% endfor %}

Question: {{question}}
Answer:
"""
    )
]

prompt_builder = ChatPromptBuilder(template=template)

chat_generator = HuggingFaceAPIChatGenerator(
    api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "meta-llama/Llama-3.1-8B-Instruct"},
    token=Secret.from_env_var("HF_TOKEN"),
)

In [None]:
# Initialize the pipeline with the components

rag_pipeline = Pipeline()
rag_pipeline.add_component("text_embedder", text_embedder)
rag_pipeline.add_component("retriever", retriever)
rag_pipeline.add_component("ranker", ranker)
rag_pipeline.add_component("prompt_builder", prompt_builder)
rag_pipeline.add_component("llm", chat_generator)

In [None]:
# Connect the components to each other

rag_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
rag_pipeline.connect("retriever.documents", "ranker.documents")
rag_pipeline.connect("ranker", "prompt_builder")
rag_pipeline.connect("prompt_builder.prompt", "llm.messages")

<haystack.core.pipeline.pipeline.Pipeline object at 0x32e75b4d0>
🚅 Components
  - text_embedder: SentenceTransformersTextEmbedder
  - retriever: InMemoryEmbeddingRetriever
  - ranker: SentenceTransformersDiversityRanker
  - prompt_builder: ChatPromptBuilder
  - llm: HuggingFaceAPIChatGenerator
🛤️ Connections
  - text_embedder.embedding -> retriever.query_embedding (List[float])
  - retriever.documents -> ranker.documents (List[Document])
  - ranker.documents -> prompt_builder.documents (List[Document])
  - prompt_builder.prompt -> llm.messages (List[ChatMessage])

In [None]:
# Query the pipeline with a question that the indexed documents do not directly answer
question = "How many human rights are there?"

response = rag_pipeline.run(
    {
        "text_embedder": {"text": question},
        "prompt_builder": {"question": question},
        "ranker": {"query": question},
    }
)

print(response["llm"]["replies"][0].text)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

It seems that there is not enough information given in the human rights protocols provided to accurately answer the question. However, we can inform you that there are several types of human rights documents that this could be referring too. Event the most widely respected declared world document on human rights for Example - Exernal and some Individual (Part 1 Art.) and some other attempted Separation apart include: The convention lists several key rights such as 

1. Right to Life 
2. Right to Liberty and Security 
3. Freedom from Torture 
4. Freedom from Slavery 
5. Right to a Fair Trial 
6. No Punishment without Law 
7. Respect for Family Life 
... (and throughout given information 44 protocals  - are actually chapter and not... How is the answer 
 

Not possible to answer your question due to lack of information, however we can tell you Event the most widely respected declared world document on human rights.


In [None]:
# Query the pipeline with a question that the indexed documents do cover
question = "What's the Right of Fair Trial?"

response = rag_pipeline.run(
    {
        "text_embedder": {"text": question},
        "prompt_builder": {"question": question},
        "ranker": {"query": question},
    }
)

print(response["llm"]["replies"][0].text)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

The information you provided does not directly list the "Right of Fair Trial" but looking under articles of the Convention for the Protection of Human Rights and Fundamental Freedoms, Article 6, also known as the Right to a Fair Trial, gives a clear idea.

 Article 6. Right to a fair Trial
 

1. Everyone is entitled to a fair and public hearing within a reasonable time by an independent and impartial tribunal established by law.
 
2, everybody shall be presumed innocent until proven guilty by a final decision of a competent court.
 
3. Everyone charged with a criminal offence has the following minimum rights:

      a to be informed promptly, in a language which he understands and in detail, of the charges, if any, against him.
      b to have adequate time and facilities for the preparation of his defence.
      c to defend himself in person or through legal assistance of his own choosing or, if he has not sufficient means to pay for legal assistance, to be given it free when the inte