# Fine Tuning Embeddings
* Notebook by Adam Lang
* Date: 7/21/2024

# Overview
* This notebook walks through fine-tuning an embedding model with the `SentenceTransformers` library.

# Process
* The process is as follows:
1. Download model
2. Dataset
 * Adapt to custom data.
3. Training arguments
4. Train
5. Test
 * Is the model performing well?
6. Model
 * Saving the model.

## Install/Import Packages

In [1]:
# transformers pytorch
!pip install transformers[torch]


In [2]:
## SentenceTransformers library
!pip install -U sentence-transformers

Collecting sentence-transformers
  Downloading sentence_transformers-3.0.1-py3-none-any.whl (227 kB)
Installing collected packages: sentence-transformers
Successfully installed sentence-transformers-3.0.1


In [3]:
## huggingface datasets
!pip install datasets

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl (6

In [4]:
# imports
from sentence_transformers import SentenceTransformer
import torch



## Load Dataset
* `sentence-transformers/all-nli`
* dataset: https://huggingface.co/datasets/sentence-transformers/all-nli

In [5]:
## load dataset
from datasets import load_dataset

# all-nli dataset
dataset = load_dataset("sentence-transformers/all-nli", "triplet")

# three datasets
train_dataset = dataset['train']
eval_dataset = dataset['dev']
test_dataset = dataset['test']


In [6]:
## train_dataset
train_dataset

Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 557850
})

In [29]:
train_dataset[0:4]

{'anchor': ['A person on a horse jumps over a broken down airplane.',
  'Children smiling and waving at camera',
  'A boy is jumping on skateboard in the middle of a red bridge.',
  'Two blond women are hugging one another.'],
 'positive': ['A person is outdoors, on a horse.',
  'There are children present',
  'The boy does a skateboarding trick.',
  'There are women showing affection.'],
 'negative': ['A person is at a diner, ordering an omelette.',
  'The kids are frowning',
  'The boy skates down the sidewalk.',
  'The women are sleeping.']}

In [9]:
## closer look at train_dataset
train_dataset.to_pandas()

| | anchor | positive | negative |
|---|---|---|---|
| 0 | A person on a horse jumps over a broken down a... | A person is outdoors, on a horse. | A person is at a diner, ordering an omelette. |
| 1 | Children smiling and waving at camera | There are children present | The kids are frowning |
| 2 | A boy is jumping on skateboard in the middle o... | The boy does a skateboarding trick. | The boy skates down the sidewalk. |
| 3 | Two blond women are hugging one another. | There are women showing affection. | The women are sleeping. |
| 4 | A few people in a restaurant setting, one of t... | The diners are at a restaurant. | The people are sitting at desks in school. |
| ... | ... | ... | ... |
| 557845 | and they're the ones that have screamed so muc... | They do not seem to like to pay taxes despite ... | They love increasing taxes and they actually w... |
| 557846 | overtime, credit hours, or compensatory time),... | Overtime, credit hours, compensatory time, and... | There is no item for overtime. |
| 557847 | cook and eat and to have mainly mainly i guess... | Getting pleasure from making good food. | People don't like to eat my food. |
| 557848 | Rocker Stevie Nicks solved the wandering impla... | Tina Turner's breasts drifted far apart, says ... | Breast implants are a perfectly safe form of p... |


Summary:
* If you plan to use these embeddings in a RAG pipeline, a triplet dataset like this one is a good fit for fine-tuning, because each row pairs a query with both a relevant and an irrelevant response.
* The dataset above contains three columns:
1. `anchor`: the original sentence or query.
2. `positive`: the correct or relevant response to the anchor.
3. `negative`: an incorrect or irrelevant response to the anchor.

* A **Triplet** dataset is NOT entirely necessary to train a custom SentenceTransformers model.
   * You can use a **Pairs** dataset.
   * Many more datasets for fine-tuning are available on the Hugging Face Hub, and the SBERT docs give an overview of the supported formats: https://sbert.net/docs/sentence_transformer/dataset_overview.html
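
To make the triplet/pairs distinction concrete, here is a minimal sketch in plain Python. It takes column-oriented triplets (the same shape that `train_dataset[0:4]` returned above) and turns them into pairs by dropping the `negative` column. The helper name `triplets_to_pairs` is hypothetical and not part of any library.

```python
def triplets_to_pairs(columns):
    """Drop the 'negative' column from a column-oriented batch of triplets."""
    return {
        name: list(values)
        for name, values in columns.items()
        if name != "negative"
    }


# One row copied from the dataset preview above.
triplets = {
    "anchor": ["Children smiling and waving at camera"],
    "positive": ["There are children present"],
    "negative": ["The kids are frowning"],
}

pairs = triplets_to_pairs(triplets)
print(sorted(pairs))  # ['anchor', 'positive']
```

On a Hugging Face `Dataset` object, `train_dataset.remove_columns("negative")` should give the equivalent result without a custom helper.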

In [10]:
## test_dataset
test_dataset

Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 6609
})

In [11]:
## eval/validation data
eval_dataset

Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 6584
})

In [32]:
557850 + 6609 + 6584

571043

In [37]:
0.01*2000

20.0

In [38]:
## cut down size of 3 datasets
train_dataset = dataset['train'].select(range(2000))
eval_dataset = dataset['dev'].select(range(20))
test_dataset = dataset['test'].select(range(20))
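
Note that `select(range(2000))` keeps the first 2,000 rows in stored order, so the subset may not be representative of the full dataset. As a hedged alternative, the sketch below draws a reproducible random sample of row indices in plain Python. The helper `sample_indices` and the seed value are assumptions made for illustration; the row counts are the ones reported for each split above.

```python
import random


def sample_indices(num_rows, subset_size, seed=42):
    """Return a sorted, reproducible random sample of row indices."""
    rng = random.Random(seed)
    return sorted(rng.sample(range(num_rows), subset_size))


train_idx = sample_indices(557850, 2000)  # train split size
eval_idx = sample_indices(6584, 20)       # dev split size
test_idx = sample_indices(6609, 20)       # test split size

print(len(train_idx), len(eval_idx), len(test_idx))  # 2000 20 20
```

Because `Dataset.select` accepts any iterable of indices, these lists could be passed in directly, for example `dataset['train'].select(train_idx)`.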

In [39]:
train_dataset.to_pandas()

| | anchor | positive | negative |
|---|---|---|---|
| 0 | A person on a horse jumps over a broken down a... | A person is outdoors, on a horse. | A person is at a diner, ordering an omelette. |
| 1 | Children smiling and waving at camera | There are children present | The kids are frowning |
| 2 | A boy is jumping on skateboard in the middle o... | The boy does a skateboarding trick. | The boy skates down the sidewalk. |
| 3 | Two blond women are hugging one another. | There are women showing affection. | The women are sleeping. |
| 4 | A few people in a restaurant setting, one of t... | The diners are at a restaurant. | The people are sitting at desks in school. |
| ... | ... | ... | ... |
| 1995 | A man is outside. | A man standing in a narrow alley posing for a ... | A shirtless man wearing a white turban is clim... |
| 1996 | A man is outside. | A man standing in a narrow alley posing for a ... | A crowd of people dancing outside. |
| 1997 | A man is outside. | A man standing in a narrow alley posing for a ... | A woman in a striped shirt folds her arms whil... |
| 1998 | A man is outside. | A man standing in a narrow alley posing for a ... | A woman with a straw hat is sitting on steps o... |


In [40]:
eval_dataset.to_pandas()

| | anchor | positive | negative |
|---|---|---|---|
| 0 | Two women are embracing while holding to go pa... | Two woman are holding packages. | The men are fighting outside a deli. |
| 1 | Two young children in blue jerseys, one with t... | Two kids in numbered jerseys wash their hands. | Two kids in jackets walk to school. |
| 2 | A man selling donuts to a customer during a wo... | A man selling donuts to a customer. | A woman drinks her coffee in a small cafe. |
| 3 | Two young boys of opposing teams play football... | boys play football | dog eats out of bowl |
| 4 | A man in a blue shirt standing in front of a g... | A man is wearing a blue shirt | A man is wearing a black shirt |
| 5 | Under a blue sky with white clouds, a child re... | A child is reaching to touch the propeller of ... | A child is playing with a ball. |
| 6 | A woman is doing a cartwheel while wearing a b... | A woman is doing a cartwheel. | A woman is doing a cartwheel and falls on her ... |
| 7 | A woman is doing a cartwheel while wearing a b... | A woman is doing a cartwheel. | A woman is fixing her home. |
| 8 | Two men on bicycles competing in a race. | People are riding bikes. | A few people are catching fish. |
| 9 | A young boy in a field of flowers carrying a ball | boy in field | dog in pool |


### Model Loading
* We will load the `BAAI/bge-large-en` model, a popular and widely used embedding model for LLM-RAG applications.
* Another model to try:
  * `nomic-embed-text-v1`

* We will do the following
1. Load the model to finetune
2. (optional) model card data

In [12]:
## load model
model = SentenceTransformer("BAAI/bge-large-en")


## Training Arguments
* These are the SentenceTransformer training arguments for the embedding model finetuning.

In [13]:
## imports
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

### Loss Function - defined
* Loss functions quantify how well a model performs for a given batch of data, allowing an optimizer to update the model weights to produce more favorable (e.g. lower) loss values. This is the core of the training process.
* According to the SBERT docs, there is no single loss function that works best for all use-cases.
   * Instead, the loss function to use greatly depends on your available data and on your target task(s).
   * `MultipleNegativesRankingLoss` is a good choice when you only have positive pairs: for each anchor, the positives belonging to the other n-1 samples in the batch are used as negatives (in-batch negatives).
      * source: https://www.philschmid.de/fine-tune-embedding-model-for-rag
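
To illustrate the in-batch negatives idea, here is a simplified pure-Python sketch of the computation, not the library's implementation. For each anchor it scores every positive in the batch by scaled cosine similarity and applies cross-entropy with the matching positive as the target. The function names are hypothetical, the `scale=20.0` default mirrors what I understand to be the library default (treat it as an assumption), and the sketch ignores the explicit `negative` column.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def mnr_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where positives[i] is the target for anchors[i]
    and every other positives[j] in the batch acts as a negative."""
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [scale * cosine(anchor, p) for p in positives]
        # Numerically stable log-sum-exp over all candidates in the batch.
        top = max(logits)
        log_sum_exp = top + math.log(sum(math.exp(x - top) for x in logits))
        total += log_sum_exp - logits[i]
    return total / len(anchors)


# Toy 2-d "embeddings": aligned pairs should score a lower loss than swapped ones.
anchors = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[0.9, 0.1], [0.1, 0.9]]
swapped = [[0.1, 0.9], [0.9, 0.1]]

print(mnr_loss(anchors, aligned) < mnr_loss(anchors, swapped))  # True
```

A larger batch gives each anchor more in-batch negatives to be ranked against, so batch size directly shapes the training signal with this loss.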

In [14]:
## define loss function
loss = MultipleNegativesRankingLoss(model)

In [41]:
## training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameters - local directory saving
    output_dir="models/mpnet-base-all-nli-triplet",
    # Optional train parameters
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    fp16=True,  # set to False if your GPU can't run FP16
    bf16=False,  # set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from having no duplicate samples in a batch
    # Optional tracking + debugging params:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-all-nli-triplet" # used if W&B is installed
    )

## Train Fine-Tuned Embedding Model

In [42]:
## setup HF trainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss
)

In [43]:
## train
trainer.train()

| Step | Training Loss | Validation Loss |
|---|---|---|
| 100 | 0.3082 | 0.680275 |


TrainOutput(global_step=125, training_loss=0.2810287742614746, metrics={'train_runtime': 73.2465, 'train_samples_per_second': 27.305, 'train_steps_per_second': 1.707, 'total_flos': 0.0, 'train_loss': 0.2810287742614746, 'epoch': 1.0})
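
The step counts in this output can be checked by hand from the training arguments. The sketch below assumes a single device and no gradient accumulation; rounding the warmup step count up is also an assumption about how the trainer handles `warmup_ratio`.

```python
import math

train_rows = 2000     # rows kept in train_dataset
batch_size = 16       # per_device_train_batch_size
epochs = 1            # num_train_epochs
warmup_ratio = 0.1    # warmup_ratio
eval_steps = 100      # eval_steps

steps_per_epoch = math.ceil(train_rows / batch_size)
total_steps = steps_per_epoch * epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)
evaluations = total_steps // eval_steps

print(total_steps, warmup_steps, evaluations)  # 125 13 1
```

The 125 total steps match `global_step=125`, and with `eval_steps=100` only one evaluation fits inside the run, which is why the log contains a single row at step 100.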

## Evaluate the Model

In [44]:
from sentence_transformers.evaluation import TripletEvaluator


#setup test_evaluator
test_evaluator = TripletEvaluator(
    anchors=test_dataset['anchor'],
    positives=test_dataset['positive'],
    negatives=test_dataset['negative'],
    name='all-nli-test'

)

# pass test_evaluator to model
test_evaluator(model)

{'all-nli-test_cosine_accuracy': 1.0,
 'all-nli-test_dot_accuracy': 0.0,
 'all-nli-test_manhattan_accuracy': 0.95,
 'all-nli-test_euclidean_accuracy': 1.0,
 'all-nli-test_max_accuracy': 1.0}
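
As a rough illustration of what the cosine accuracy figure means, the sketch below counts the fraction of triplets whose anchor is more similar to its positive than to its negative. It is a simplified stand-in rather than the `TripletEvaluator` source; the function names and the toy 2-d vectors are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def triplet_cosine_accuracy(anchors, positives, negatives):
    """Fraction of triplets where the positive is closer to the anchor than the negative."""
    correct = sum(
        cosine(a, p) > cosine(a, n)
        for a, p, n in zip(anchors, positives, negatives)
    )
    return correct / len(anchors)


# Toy embeddings: the first triplet is ranked correctly, the second is not.
anchors = [[1.0, 0.0], [0.0, 1.0]]
positives = [[0.9, 0.1], [0.8, 0.2]]
negatives = [[0.1, 0.9], [0.2, 0.8]]

print(triplet_cosine_accuracy(anchors, positives, negatives))  # 0.5
```

With only 20 test triplets, each triplet moves the score by 0.05, so a value such as 0.95 corresponds to 19 of 20 triplets ranked correctly.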

Summary:
* The training data was cut to 2,000 rows, and the eval and test sets to only 20 rows each, purely for demonstration and to avoid overloading the GPU. With so few test examples, the high accuracy and low loss reported above should not be read as a reliable measure of model quality.
* The same code can be adapted to your own datasets. It shows that you can take the "cookie cutter" SentenceTransformer infrastructure and fine-tune a model so that it better captures the context of your text.

## Save model to HuggingFace

In [None]:
# Save trained model
# model.save_pretrained("models/mpnet-base-all-nli-triplet/final")

# push saved model to huggingface hub
# model.push_to_hub('mpnet-base-all-nli-triplet')