[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aurelio-labs/cookbook/blob/main/information-retrieval/sentence-transformers/v3-fine-tuning.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/aurelio-labs/cookbook/blob/main/information-retrieval/sentence-transformers/v3-fine-tuning.ipynb)

# Fine-tune Embedding Models with Sentence Transformers 3

In this blog post, we will walk through the process of fine-tuning embedding models using Sentence Transformers 3 to enhance Retrieval-Augmented Generation (RAG) performance.

## Install the Necessary Libraries
Install the following libraries:
- PyTorch
- Sentence Transformers (HF)
- Transformers (HF)
- Datasets (HF)

We are currently using Python 3.11.5.

In [None]:
!accelerate launch --multi_gpu --num_processes=2 v4_ft_embedding.py

In [5]:
!source activate agent
!pip install -q \
    "torch==2.6.0" \
    "sentence-transformers" \
    "datasets==2.19.1"  \
    "transformers" \
    "accelerate"



After installing the necessary libraries, register on [Hugging Face](https://huggingface.co/join), since we will use the Hugging Face Hub to push our models and training logs.

Get your access token [here](https://huggingface.co/settings/tokens).

In [11]:
# # Alternative: log into your HF account with an explicit token and store it on disk.
# # Never commit a real token to a notebook or repository.
# from huggingface_hub import login

# login(token="ADD YOUR TOKEN HERE", add_to_git_credential=True)

from huggingface_hub import notebook_login
from huggingface_hub import interpreter_login

notebook_login()
# interpreter_login()


In [22]:
import os
os.environ["WANDB_API_KEY"] = "ADD YOUR WANDB API KEY HERE"  # never commit a real key
os.environ["WANDB_PROJECT"] = "Fine-tune model with Sentence Transformer"
os.environ["WANDB_NAME"] = "ft-with-st-v3"

## Load the pretrained model

In [7]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("models/stella-fine-tuned-v4.5")

## Dataset preparation

The Hugging Face Hub hosts many datasets that can be used to fine-tune embedding models. The [dataset overview](https://sbert.net/docs/sentence_transformer/dataset_overview.html) describes the structure a dataset must follow to be usable for fine-tuning embeddings.

We are going to use [enelpol/rag-mini-bioasq](https://huggingface.co/datasets/enelpol/rag-mini-bioasq), which contains 4,719 question-answer passages from the BioASQ challenges on biomedical semantic indexing and question answering (QA) ([dataset for task b BioASQ11](http://participants-area.bioasq.org/datasets/)). It can be used in the *Positive Pair* configuration.

We have to load the dataset, and we can do it using the HF datasets library.


In [None]:
from datasets import load_dataset

# Load dataset from HF hub

# (anchor, positive, negative)
all_nli_triplet_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:500]", cache_dir="./dataset")
# (sentence1, sentence2) + score
stsb_pair_score_train = load_dataset("sentence-transformers/stsb", split="train[:500]", cache_dir="./dataset")

# (anchor, positive, negative)
all_nli_triplet_dev = load_dataset("sentence-transformers/all-nli", "triplet", split="dev[:400]", cache_dir="./dataset")
# (sentence1, sentence2, score)
stsb_pair_score_dev = load_dataset("sentence-transformers/stsb", split="validation[:400]", cache_dir="./dataset")


Load the custom FAQ dataset and generate negative samples using in-batch negatives and hard-negative sampling.

In [18]:
# Mine hard negatives
# https://github.com/UKPLab/sentence-transformers/releases/tag/v3.1.0
# pip install sentence-transformers[train]==3.1.1
# RuntimeError: The NVIDIA driver on your system is too old
import pandas as pd
from datasets import Dataset
from sentence_transformers.util import mine_hard_negatives


def hard_negative_sampling(dataset):
    dataset = mine_hard_negatives(
        dataset=dataset,
        model=model,
        range_min=10,
        range_max=50,
        max_score=0.8,               # maximum similarity allowed for a negative; controls how hard the negatives are
        relative_margin=0.05,        # 0.05 means that the negative is at most 95% as similar to the anchor as the positive
        num_negatives=5,             # 10 or less is recommended
        sampling_strategy="random",  # "top" would instead sample the top candidates as negatives
        batch_size=128,              # Adjust as needed
        use_faiss=True,              # Optional: use faiss/faiss-gpu for faster similarity search
    )
    return dataset


# Load the custom dataset
custom_samples = pd.read_csv(
    "processed_crm_similar_pair.csv",
    header=0, sep=",", encoding="utf-8", index_col=False
)
custom_dataset = Dataset.from_dict({
    "anchor": custom_samples["standard_question"].tolist(),
    "positive": custom_samples["similar_question"].tolist(),
})

# After mining, each row is (anchor, positive, negative)
custom_dataset = hard_negative_sampling(custom_dataset)

# Split into training and evaluation sets
custom_dataset = custom_dataset.train_test_split(test_size=0.1, seed=123, shuffle=True)
custom_dataset_train = custom_dataset["train"]
custom_dataset_dev = custom_dataset["test"]

Found 6494 unique queries out of 89816 total queries.
Found an average of 13.831 positives per query.




Metric       Positive       Negative     Difference
Count          89,816         93,685               
Mean           0.8957         0.6806         0.2093
Median         0.9107         0.6904         0.2056
Std            0.0723         0.0679         0.0843
Min            0.1593         0.3519        -0.3537
25%            0.8568         0.6395         0.1540
50%            0.9107         0.6904         0.2056
75%            0.9503         0.7327         0.2611
Max            1.0000         0.7946         0.6310
Skipped 183,610 potential negatives (55.44%) due to the relative_margin of 0.05.
Skipped 9,918 potential negatives (6.72%) due to the max_score of 0.8.
Could not find enough negatives for 355395 samples (79.14%). Consider adjusting the range_max, range_min, relative_margin and max_score parameters if you'd like to find more valid negatives.
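
To make the two "Skipped ..." lines above more concrete, here is a minimal pure-Python sketch of how a `max_score` cap and a `relative_margin` rule could reject candidate negatives. It is a toy re-implementation under our reading of those parameters, not the library's code; the function name `filter_negatives` and the example scores are made up for illustration.

```python
def filter_negatives(positive_score, candidate_scores, max_score=0.8, relative_margin=0.05):
    """Return the candidate similarity scores that pass both filters (toy sketch)."""
    # A negative may be at most (1 - relative_margin) times as similar
    # to the anchor as the positive is.
    margin_limit = positive_score * (1.0 - relative_margin)
    kept = []
    for score in candidate_scores:
        if score > max_score:
            continue  # rejected: too similar to the anchor in absolute terms
        if score > margin_limit:
            continue  # rejected: too close to the positive's similarity
        kept.append(score)
    return kept


# Hypothetical anchor-positive similarity and four candidate negatives.
positive_score = 0.82
candidates = [0.85, 0.79, 0.70, 0.55]
kept = filter_negatives(positive_score, candidates)
print(kept)
```

With these toy numbers, 0.85 is dropped by the `max_score` cap and 0.79 by the margin rule (0.82 × 0.95 = 0.779), so only the two easier candidates remain. Loosening either threshold keeps harder negatives at the risk of admitting false negatives.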


In [23]:
custom_dataset_train[:20]

{'anchor': ['蓝牙名称是什么？',
  '保险杠有什么用？',
  '什么是珍珠漆？',
  '如何确认端到端VLM报名是否成功？',
  '车身高度随速调节是什么功能？',
  '中控屏可以在行驶时播放视频吗？',
  'IMEI是什么？',
  '实拍星环是有光斑的，星环大灯的帧率是1/300秒是不是每次拍照星环灯都会频闪？',
  '车辆的节能模式开启后车辆的哪些功能会关闭或禁用？',
  '账号管理怎么进入？',
  '车辆开启空调后多长时间可以出热风？',
  '副驾单人充气床垫是否适配L6',
  '如何设置上车后不自动展开后视镜?',
  '在外面加装踏板是否会影响车辆上牌？',
  '在理想汽车中心购买轮胎后，可以享受什么保障？',
  '用车服务包购买后是否可以申请退款？',
  '屏幕怎么打开？屏幕会自动点亮吗？',
  'PM2.5快速净化风速控制逻辑？',
  '车辆播报导航信息时，播放的多媒体音量会发生变化吗？',
  '车辆车顶的激光雷达如何清洗？'],
 'positive': ['如何查看汽车的蓝牙设备名称？',
  '保险杠在车辆行驶中的主要功能是什么?',
  '你知道吗？有种漆叫珍珠漆，它的光泽是怎么做到的？',
  '提交端到端VLM万人团的申请后，如何得知报名结果?',
  '我的车能根据速度自动调整车身高度吗？',
  '开车时，中控屏能播放电影吗？',
  '如何找到我设备的IMEI识别码？',
  '星环灯在镜头里老闪，是不是因为它的帧率设置？',
  '节能模式下，音响系统会有什么变化？',
  '怎么才能找到账号管理？',
  '在寒冷的早晨，理想电动车空调需要多长时间才能让车内变暖?',
  '请问L6适配的充气床垫型号？',
  '在车机上如何设定不自动展开后视镜?',
  '外部加装踏板，对新车上牌有无影响？',
  '理想汽车中心售出的轮胎，有无额外的客户保障?',
  '购车后购买的服务包，我还能申请退款吗?',
  '车辆解锁后，仪表盘和中控屏会自动打开吗？',
  '当PM2.5过高时，如何让汽车自动提高风速净化?',
  '播放音乐时导航有语音，会自动减小音乐声吗?',
  '清洗车顶激光雷达时应避免什么？'],
 'negative': ['在App里怎么给车辆的蓝牙起个新名字?',


In [None]:
# Combine all datasets into a dictionary with dataset names to datasets
train_dataset = {
    "all-nli-triplet": all_nli_triplet_train,
    "stsb": stsb_pair_score_train,
    "custom": custom_dataset_train
}

eval_dataset = {
    "all-nli-triplet": all_nli_triplet_dev,
    "stsb": stsb_pair_score_dev, 
    "custom": custom_dataset_dev
}

## Define model evaluator that will be used for training

In [None]:
# Evaluate the model
from sentence_transformers import evaluation

model_evaluator = evaluation.TripletEvaluator(
    anchors=eval_dataset['custom']["anchor"],
    positives=eval_dataset['custom']["positive"],
    negatives=eval_dataset['custom']["negative"],
    name="custom_dev",
)
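
As a rough illustration of what a triplet evaluation measures, the sketch below computes the share of triplets in which the anchor is closer (by cosine similarity) to its positive than to its negative. It uses tiny hand-written vectors instead of real embeddings, and `triplet_accuracy` is a hypothetical helper, not the library's evaluator.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def triplet_accuracy(anchors, positives, negatives):
    """Fraction of triplets where sim(anchor, positive) > sim(anchor, negative)."""
    correct = sum(
        1
        for a, p, n in zip(anchors, positives, negatives)
        if cosine(a, p) > cosine(a, n)
    )
    return correct / len(anchors)


# Toy 2-d "embeddings"; the third triplet is deliberately ranked wrongly.
anchors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
positives = [[0.9, 0.1], [0.2, 0.8], [1.0, 0.0]]
negatives = [[0.0, 1.0], [1.0, 0.0], [1.0, 0.9]]
accuracy = triplet_accuracy(anchors, positives, negatives)
print(round(accuracy, 3))
```

Two of the three toy triplets are ordered correctly, so the accuracy is 2/3.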

## Define loss function that will be used for training

In this case, we use MultipleNegativesRankingLoss to fine-tune our embedding model. This choice follows from our dataset format, which consists of positive text pairs (optionally extended with mined hard negatives). See the [dataset format](https://sbert.net/docs/sentence_transformer/training_overview.html#dataset-format) and [loss function](https://sbert.net/docs/sentence_transformer/loss_overview.html) documentation to determine which loss function fits your use case.
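
To show the idea behind an in-batch-negatives ranking loss, here is a small pure-Python sketch: for each anchor, every positive in the batch is scored, and cross-entropy rewards ranking the anchor's own positive first. This is a simplified illustration, not the library implementation; the function name `in_batch_ranking_loss`, the toy vectors, and the use of cosine similarity with a scale of 20 are assumptions made for the example.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def in_batch_ranking_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where positives[i] is the correct match for anchors[i].

    All other positives in the batch act as negatives for anchor i.
    """
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [scale * cosine(anchor, p) for p in positives]
        # Log-sum-exp with max subtraction for numerical stability.
        top = max(logits)
        log_z = top + math.log(sum(math.exp(x - top) for x in logits))
        total += log_z - logits[i]
    return total / len(anchors)


anchors = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[0.9, 0.1], [0.1, 0.9]]   # each positive sits near its own anchor
swapped = [[0.1, 0.9], [0.9, 0.1]]   # each positive sits near the other anchor
loss_aligned = in_batch_ranking_loss(anchors, aligned)
loss_swapped = in_batch_ranking_loss(anchors, swapped)
print(loss_aligned < loss_swapped)
```

The loss is small when each anchor's own positive is its nearest neighbour in the batch and large otherwise, which is also why duplicate texts within a batch are harmful: a duplicate of the positive would be treated as a negative.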


In [29]:
from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss

# (anchor, positive), (anchor, positive, negative)
mnrl_loss = MultipleNegativesRankingLoss(model)

# (sentence_A, sentence_B) + score
cosent_loss = CoSENTLoss(model)

# Create a mapping with dataset names to loss functions, so the trainer knows which loss to apply where.
losses={
    "all-nli-triplet": mnrl_loss,
    "stsb": cosent_loss,
    "custom": mnrl_loss,
}
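
The pairing between dataset layout and loss used above can be summarised in a few lines of plain Python. The helper `suggest_loss` is hypothetical and only encodes the three layouts that appear in this notebook; it returns class names as strings and does not import the library.

```python
def suggest_loss(column_names):
    """Map a dataset's column layout to the loss used for it in this notebook."""
    columns = tuple(column_names)
    if columns in (("anchor", "positive"), ("anchor", "positive", "negative")):
        return "MultipleNegativesRankingLoss"
    if columns == ("sentence1", "sentence2", "score"):
        return "CoSENTLoss"
    raise ValueError(f"No loss configured for columns {columns}")


# Column layouts as they appear in the comments of the cells above.
layouts = {
    "all-nli-triplet": ["anchor", "positive", "negative"],
    "stsb": ["sentence1", "sentence2", "score"],
    "custom": ["anchor", "positive", "negative"],
}
suggested = {name: suggest_loss(cols) for name, cols in layouts.items()}
print(suggested)
```

A check like this can catch a mismatch early, for example a scored-pair dataset accidentally routed to a ranking loss.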

## Fine-tune embedding model with SentenceTransformerTrainer

Now that we've prepared our data and model, we're ready to fine-tune our embedding model using the SentenceTransformerTrainer.

To configure our training process, we'll use the SentenceTransformerTrainingArguments class. This tool allows us to specify various parameters that can impact training performance and help with tracking and debugging. We'll be using parameter values based on those recommended in the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/training_overview.html#training-arguments). However, it's important to note that these are just starting points. For optimal results, you should experiment with different values tailored to your specific dataset and task.


In [27]:
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

 
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=os.getenv("WANDB_NAME"), # Save checkpoints
    # Optional training parameters:
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,   # Train with FP16 mixed precision
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",  # `eval_strategy` was introduced in transformers 4.41.0
    eval_steps=500,         # run evaluation (including the evaluator, if set) every 500 training steps
    save_strategy="steps",
    save_steps=500,   # save checkpoints during training
    save_total_limit=3,
    logging_steps=500,
    report_to="wandb",
    run_name=os.getenv('WANDB_NAME'),

    load_best_model_at_end=True,  # if True, load the best model found during evaluation when training ends; defaults to `False`
    metric_for_best_model='eval_loss',  # used with `load_best_model_at_end` to compare checkpoints; an alternative is eval_pearson_cosine
    greater_is_better=False,  # used together with the two parameters above; lower eval_loss is better
)

Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.
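
To get a feel for how these arguments interact, the following back-of-the-envelope sketch estimates the number of optimizer steps, warmup steps, and evaluation points. It assumes a single device, no gradient accumulation, that the last partial batch is kept, and it uses the 4,719-pair figure quoted earlier purely as an example size; the no-duplicates batch sampler may yield a slightly different batch count in practice. `training_schedule` is a hypothetical helper.

```python
import math


def training_schedule(num_samples, batch_size, epochs, warmup_ratio, eval_steps):
    """Estimate step counts for a run (single device, no gradient accumulation)."""
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * epochs
    # Warmup covers the first `warmup_ratio` share of all steps, rounded up.
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    # Steps at which an evaluation (and checkpoint save) would be triggered.
    eval_points = list(range(eval_steps, total_steps + 1, eval_steps))
    return steps_per_epoch, total_steps, warmup_steps, eval_points


schedule = training_schedule(
    num_samples=4719, batch_size=16, epochs=5, warmup_ratio=0.1, eval_steps=500
)
print(schedule)
```

Under these assumptions the run has 295 steps per epoch and 1,475 steps in total, of which 148 are warmup, and evaluation fires only at steps 500 and 1,000. For small datasets it can be worth lowering `eval_steps` so that `load_best_model_at_end` has more checkpoints to choose from.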


In [31]:
from sentence_transformers import SentenceTransformerTrainer
 
# trainer = SentenceTransformerTrainer(
#     model=model,
#     args=args,
#     train_dataset=train_dataset,
#     eval_dataset=eval_dataset,
#     loss=losses,
#     # evaluator=model_evaluator,
# )

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=custom_dataset_train,
    eval_dataset=custom_dataset_dev,
    loss=mnrl_loss,
    # evaluator=model_evaluator,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.



In [None]:
# start training the model
trainer.train()

# save the model
model.save_pretrained("./sbert-model/final")
 
# #  The model will be saved to the hub and the output directory
# trainer.save_model()

# # (Optional) Push it to the Hugging Face Hub
# trainer.model.push_to_hub(os.getenv("WANDB_NAME"))

Training on 4k samples took around 1 minute on an Nvidia A10G instance from [Modal Labs](https://modal.com/pricing). At the time of writing (July 2024), the instance costs 1.1 USD/hour, which puts the cost of this training run below 0.1 USD.
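
The cost estimate is simple arithmetic, sketched here with the figures quoted above (about 1 minute at 1.1 USD/hour). The helper name `training_cost_usd` is made up for this example.

```python
def training_cost_usd(minutes, usd_per_hour):
    """Cost of renting an instance for the given number of minutes."""
    return minutes / 60.0 * usd_per_hour


cost = training_cost_usd(minutes=1, usd_per_hour=1.1)
print(f"{cost:.4f} USD")
```

This comes to roughly 0.02 USD for the training itself; instance start-up, data loading, and evaluation time would add to the bill.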

What remains is to evaluate the fine-tuned model with the evaluator defined earlier.



In [15]:
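import torch
from sentence_transformers import SentenceTransformer

# Load the model from the path used by model.save_pretrained(...) after training
fine_tuned_model = SentenceTransformer(
    "./sbert-model/final", device="cuda" if torch.cuda.is_available() else "cpu"
)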

# Evaluate the model
from sentence_transformers import evaluation
evaluator = evaluation.TripletEvaluator(
    anchors=eval_dataset['custom']["anchor"],
    positives=eval_dataset['custom']["positive"],
    negatives=eval_dataset['custom']["negative"],
    name="custom_dev",
)
 
fine_tuned_results = evaluator(fine_tuned_model)
fine_tuned_results

{'sentence-transformers/all-mpnet-base-v2_cosine_accuracy@1': 0.8458274398868458,
 'sentence-transformers/all-mpnet-base-v2_cosine_accuracy@3': 0.9335219236209336,
 'sentence-transformers/all-mpnet-base-v2_cosine_accuracy@5': 0.9476661951909476,
 'sentence-transformers/all-mpnet-base-v2_cosine_accuracy@10': 0.9618104667609618,
 'sentence-transformers/all-mpnet-base-v2_cosine_precision@1': np.float64(0.8458274398868458),
 'sentence-transformers/all-mpnet-base-v2_cosine_precision@3': np.float64(0.31117397454031115),
 'sentence-transformers/all-mpnet-base-v2_cosine_precision@5': np.float64(0.1895332390381895),
 'sentence-transformers/all-mpnet-base-v2_cosine_precision@10': np.float64(0.09618104667609616),
 'sentence-transformers/all-mpnet-base-v2_cosine_recall@1': np.float64(0.8458274398868458),
 'sentence-transformers/all-mpnet-base-v2_cosine_recall@3': np.float64(0.9335219236209336),
 'sentence-transformers/all-mpnet-base-v2_cosine_recall@5': np.float64(0.9476661951909476),
 'sentence-t

If we focus on only a couple of metrics that are more relevant in our case, we get the following information:

| Model | MRR@10 | NDCG@10 |
|-------|--------|---------|
| all-mpnet-base-v2 (Baseline) | 0.8347 | 0.8571 |
| bge-base-en-v1.5 | 0.8965 | 0.9122 |
| all-mpnet-base-v2 Fine-tuned | 0.8919 | 0.9093 |

The fine-tuned model clearly improves on the baseline, with a 6.85% relative increase in MRR@10 and a 6.09% relative increase in NDCG@10. It comes close to the performance of the bge-base-en-v1.5 embeddings (0.8919 vs. 0.8965 MRR@10 and 0.9093 vs. 0.9122 NDCG@10).
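
For readers who want to check the numbers, the sketch below shows how the relative gains follow from the table, together with textbook definitions of MRR@k and NDCG@k for binary relevance. The helper names and the toy ranking lists are illustrative assumptions; the library's evaluator may differ in details such as tie handling.

```python
import math


def mrr_at_k(ranked_relevance, k=10):
    """Mean reciprocal rank; each query is a list of 0/1 flags in ranked order."""
    total = 0.0
    for flags in ranked_relevance:
        for rank, rel in enumerate(flags[:k], start=1):
            if rel:
                total += 1.0 / rank
                break  # only the first relevant hit counts
    return total / len(ranked_relevance)


def ndcg_at_k(flags, k=10):
    """NDCG for one query with binary relevance flags in ranked order."""
    dcg = sum(rel / math.log2(rank + 1) for rank, rel in enumerate(flags[:k], start=1))
    ideal = sorted(flags, reverse=True)
    idcg = sum(rel / math.log2(rank + 1) for rank, rel in enumerate(ideal[:k], start=1))
    return dcg / idcg if idcg > 0 else 0.0


def relative_gain(baseline, new):
    """Relative improvement over the baseline, in percent."""
    return (new - baseline) / baseline * 100.0


# Toy rankings: hit at rank 1, hit at rank 2, no hit.
toy_queries = [[1, 0, 0], [0, 1, 0], [0, 0, 0]]
toy_mrr = mrr_at_k(toy_queries)
toy_ndcg = ndcg_at_k([0, 1, 0])

# Values taken from the table above.
mrr_gain = round(relative_gain(0.8347, 0.8919), 2)
ndcg_gain = round(relative_gain(0.8571, 0.9093), 2)
print(toy_mrr, round(toy_ndcg, 4), mrr_gain, ndcg_gain)
```

The gains are relative to the baseline score, not differences in percentage points: the absolute change in MRR@10 is 0.0572.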



## Conclusion

Embedding models play a crucial role in the success of Retrieval-Augmented Generation (RAG) applications, as the quality of retrieved context directly impacts the generated answers. Using the Sentence Transformers 3 library, we fine-tuned the all-mpnet-base-v2 model on a biomedical question-answering dataset. The results show substantial improvements:

- MRR@10 increased from 0.8347 to 0.8919 (6.85% relative improvement)
- NDCG@10 increased from 0.8571 to 0.9093 (6.09% relative improvement)

Our fine-tuned model achieved performance comparable to the more advanced bge-base-en-v1.5 model despite starting from a lower baseline.

The fine-tuning process has become highly accessible and efficient. With only 4,719 question-answer pairs, we were able to achieve these improvements in approximately 1 minute of training time on an Nvidia A10G GPU. The estimated cost for this training was less than 0.1 USD, making it a cost-effective approach for enhancing domain-specific retrieval tasks.

This shows the value of customizing embedding models for specific domains or use cases. Significant performance gains can be realized even with a relatively small dataset and minimal training time.

