# Train and Fine-Tune Sentence Transformers Models - Notebook Companion

In [None]:
%%capture
!pip install sentence-transformers

## How Sentence Transformers models work


In [None]:
from sentence_transformers import SentenceTransformer, models

## Step 1: use an existing language model
word_embedding_model = models.Transformer('distilroberta-base')

## Step 2: apply a pooling function over the token embeddings
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

## Join steps 1 and 2 using the modules argument
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

Downloading config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/316M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading vocab.json:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/1.29M [00:00<?, ?B/s]
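To make the pooling step concrete, here is a minimal pure-Python sketch of mean pooling, which we believe is the default mode of `models.Pooling`. The helper `mean_pool` and the toy vectors are hypothetical and only illustrate the idea; the library itself works on batched tensors.

```python
def mean_pool(token_embeddings, attention_mask):
    """Average the embeddings of real tokens, ignoring padding (mask == 0)."""
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    n_tokens = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            n_tokens += 1
            for j, value in enumerate(vector):
                summed[j] += value
    # Guard against an all-padding input to avoid dividing by zero.
    return [total / max(n_tokens, 1) for total in summed]


# Toy input: three token vectors of dimension 2; the last one is padding.
token_embeddings = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
attention_mask = [1, 1, 0]
sentence_embedding = mean_pool(token_embeddings, attention_mask)
print(sentence_embedding)
```

Whatever the number of tokens, the result has the same dimension as a single token embedding, which is why the pooling layer is built from `get_word_embedding_dimension()`.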

## How to prepare your dataset for training a Sentence Transformers model


In [None]:
%%capture
!pip install datasets

In [None]:
from datasets import load_dataset

dataset_id = "embedding-data/QQP_triplets"
# dataset_id = "embedding-data/sentence-compression"

dataset = load_dataset(dataset_id)

Using custom data configuration embedding-data--QQP_triplets-ff67885711b8d7f7


Downloading and preparing dataset json/embedding-data--QQP_triplets to /root/.cache/huggingface/datasets/embedding-data___json/embedding-data--QQP_triplets-ff67885711b8d7f7/0.0.0/a3e658c4731e59120d44081ac10bf85dc7e1388126b92338344ce9661907f253...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/183M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/embedding-data___json/embedding-data--QQP_triplets-ff67885711b8d7f7/0.0.0/a3e658c4731e59120d44081ac10bf85dc7e1388126b92338344ce9661907f253. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
print(f"- The {dataset_id} dataset has {dataset['train'].num_rows} examples.")
print(f"- Each example is a {type(dataset['train'][0])} with a {type(dataset['train'][0]['set'])} as value.")
print(f"- Examples look like this: {dataset['train'][0]}")

- The embedding-data/QQP_triplets dataset has 101762 examples.
- Each example is a <class 'dict'> with a <class 'dict'> as value.
- Examples look like this: {'set': {'query': 'Why in India do we not have one on one political debate as in USA?', 'pos': ['Why cant we have a public debate between politicians in India like the one in US?'], 'neg': ['Can people on Quora stop India Pakistan debate? We are sick and tired seeing this everyday in bulk?', 'Why do politicians, instead of having a decent debate on issues going in and around the world, end up fighting always?', 'Can educated politicians make a difference in India?', 'What are some unusual aspects about politics and government in India?', 'What is debate?', 'Why does civic public communication and discourse seem so hollow in modern India?', 'What is a Parliamentary debate?', "Why do we always have two candidates at the U.S. presidential debate. yet the ballot has about 7 candidates? Isn't that a misrepresentation of democracy?", 'Wh

Convert the examples into `InputExample`s. It might take around 10 minutes in Google Colab.

In [None]:
from tqdm.auto import tqdm
from sentence_transformers import InputExample

train_examples = []
n_examples = 100
## For training with the entire dataset you can use `for i in range(dataset['train'].num_rows):`

for i in tqdm(range(n_examples)):
  example = dataset['train']['set'][i]
  train_examples.append(InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]))
  # Print every 50th example to show what it looks like
  if i % 50 == 0:
    print(f"Anchor: {example['query']} --- Positive: {example['pos'][0]} --- Negative: {example['neg'][0]}")

  0%|          | 0/100 [00:00<?, ?it/s]

Anchor: Why in India do we not have one on one political debate as in USA? --- Positive: Why cant we have a public debate between politicians in India like the one in US? --- Negative: Can people on Quora stop India Pakistan debate? We are sick and tired seeing this everyday in bulk?
Anchor: When will be end of world? --- Positive: What is the end of this world? --- Negative: Where does the world end?
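The loop above keeps only the first positive and the first negative of each record. As a sketch of one possible alternative, the hypothetical helper `expand_triplets` below builds every (anchor, positive, negative) combination from a record with the same keys as the `set` field. The record is made-up toy data, and whether the extra triplets help training is not something this notebook tests.

```python
from itertools import product


def expand_triplets(record):
    """Return every (anchor, positive, negative) combination in one record."""
    return [
        (record["query"], pos, neg)
        for pos, neg in product(record["pos"], record["neg"])
    ]


# Toy record with the same keys as the dataset's 'set' field (texts are made up).
record = {"query": "q", "pos": ["p1"], "neg": ["n1", "n2", "n3"]}
triplets = expand_triplets(record)
print(len(triplets))
```

Each tuple could then be passed to `InputExample(texts=list(triplet))` in the same way as in the loop above.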


In [None]:
print(f"We have a {type(train_examples)} of length {len(train_examples)} containing {type(train_examples[0])}'s.")

We have a <class 'list'> of length 100 containing <class 'sentence_transformers.readers.InputExample.InputExample'>'s.


We wrap our training dataset in a PyTorch `DataLoader` to shuffle the examples and group them into batches.

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
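As a rough illustration of what shuffling and batching amount to, the sketch below mimics them in plain Python. `make_batches` is a hypothetical helper, not part of PyTorch, and the integers stand in for the 100 `InputExample`s reported above.

```python
import math
import random


def make_batches(items, batch_size, seed=0):
    """Shuffle a copy of items and split it into consecutive batches."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    return [shuffled[i:i + batch_size] for i in range(0, len(shuffled), batch_size)]


batches = make_batches(range(100), batch_size=16)
print(len(batches), len(batches[-1]))
```

With 100 examples and a batch size of 16 we get `math.ceil(100 / 16)` batches, the last one being smaller than the others. This is the per-epoch iteration count that the training progress bars report.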

## Loss functions for training a Sentence Transformers model


In [None]:
from sentence_transformers import losses

train_loss = losses.TripletLoss(model=model)
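To see what a triplet loss optimizes, here is a minimal sketch for a single triplet. It assumes Euclidean distance and a margin of 5, which we believe are the defaults of `losses.TripletLoss`; treat both as assumptions. The function names and vectors are illustrative only.

```python
import math


def euclidean(u, v):
    """Euclidean distance between two vectors of equal length."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss(anchor, positive, negative, margin=5.0):
    """max(d(a, p) - d(a, n) + margin, 0) for a single triplet."""
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)


anchor, positive = [0.0, 0.0], [3.0, 4.0]  # distance 5 from the anchor
near_negative = [0.0, 6.0]                 # distance 6 from the anchor
far_negative = [0.0, 12.0]                 # distance 12 from the anchor

loss_near = triplet_loss(anchor, positive, near_negative)  # 5 - 6 + 5
loss_far = triplet_loss(anchor, positive, far_negative)    # 5 - 12 + 5, clipped at 0
print(loss_near, loss_far)
```

The loss drops to zero once the negative is farther from the anchor than the positive by at least the margin, so only triplets that are still "too close" contribute to the gradient.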

## How to train a Sentence Transformer model


In [None]:
num_epochs = 10

warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)  # 10% of the total training steps
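The sketch below works through the warm-up arithmetic for the run in this notebook (7 batches per epoch, 10 epochs) and shows the shape of a linear warm-up followed by a linear decay. We assume this matches the scheduler that `fit` uses by default; `linear_warmup_factor` is a hypothetical helper, not a library function.

```python
def linear_warmup_factor(step, warmup_steps, total_steps):
    """Multiplier applied to the base learning rate at a given step."""
    if step < warmup_steps:
        # Ramp up linearly from 0 to 1 during warm-up.
        return step / max(1, warmup_steps)
    # Then decay linearly from 1 back to 0 at the last step.
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


steps_per_epoch, num_epochs = 7, 10          # 100 examples with batch size 16
total_steps = steps_per_epoch * num_epochs   # 70 optimizer steps
warmup_steps = int(total_steps * 0.1)        # 7 warm-up steps
factors = [
    linear_warmup_factor(s, warmup_steps, total_steps)
    for s in range(total_steps + 1)
]
print(warmup_steps, factors[0], factors[warmup_steps], factors[-1])
```

The learning rate therefore starts at zero, reaches its full value after the first epoch's worth of steps in this setup, and decays for the rest of training.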

In [None]:
model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=num_epochs,
          warmup_steps=warmup_steps) 

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/7 [00:00<?, ?it/s]


## How to share a Sentence Transformers model on the Hugging Face Hub

In [None]:
!huggingface-cli login


        _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
        _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
        _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
        _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
        _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

        To login, `huggingface_hub` now requires a token generated from https://huggingface.co/settings/tokens .
        
Token: 
Login successful
Your token has been saved to /root/.huggingface/token
[1m[31mAuthenticated through git-credential store but this isn't the helper defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub. Run the following command in yo

In [None]:
model.save_to_hub(
    "distilroberta-base-sentence-transformer", 
    organization="embedding-data",
    train_datasets=["embedding-data/QQP_triplets"],
    exist_ok=True, 
    )

Cloning https://huggingface.co/embedding-data/distilroberta-base-sentence-transformer into local empty directory.


Upload file pytorch_model.bin:   0%|          | 3.34k/313M [00:00<?, ?B/s]

To https://huggingface.co/embedding-data/distilroberta-base-sentence-transformer
   0e74c10..8c082cd  main -> main



'https://huggingface.co/embedding-data/distilroberta-base-sentence-transformer/commit/8c082cdedb8acb7788055e9a8f06c279c68e93dc'

## Extra: How to fine-tune a Sentence Transformer model


Now we will fine-tune the Sentence Transformers model we just shared, this time on the sentence-compression dataset.

In [None]:
modelB = SentenceTransformer('embedding-data/distilroberta-base-sentence-transformer')

Downloading:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.04k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/671 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/329M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/280 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/386 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

In [None]:
dataset_id = "embedding-data/sentence-compression"
datasetB = load_dataset(dataset_id)

Using custom data configuration embedding-data--sentence-compression-a90dfb3e5e100cf9


Downloading and preparing dataset json/embedding-data--sentence-compression to /root/.cache/huggingface/datasets/embedding-data___json/embedding-data--sentence-compression-a90dfb3e5e100cf9/0.0.0/a3e658c4731e59120d44081ac10bf85dc7e1388126b92338344ce9661907f253...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/embedding-data___json/embedding-data--sentence-compression-a90dfb3e5e100cf9/0.0.0/a3e658c4731e59120d44081ac10bf85dc7e1388126b92338344ce9661907f253. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
print(f"Examples look like this: {datasetB['train']['set'][0]}")

Examples look like this: ["The USHL completed an expansion draft on Monday as 10 players who were on the rosters of USHL teams during the 2009-10 season were selected by the League's two newest entries, the Muskegon Lumberjacks and Dubuque Fighting Saints.", 'USHL completes expansion draft']


In [None]:
train_examplesB = []
n_examples = 1500 
## For training with the entire dataset you can use `for i in range(datasetB['train'].num_rows):`

for i in tqdm(range(n_examples)):
  example = datasetB['train']['set'][i]
  train_examplesB.append(InputExample(texts=[example[0], example[1]]))
  # Print every 50th example to show what it looks like
  if i % 50 == 0:
    print(f"Anchor: {example[0]} --- Positive: {example[1]}")

  0%|          | 0/1500 [00:00<?, ?it/s]

Anchor: The USHL completed an expansion draft on Monday as 10 players who were on the rosters of USHL teams during the 2009-10 season were selected by the League's two newest entries, the Muskegon Lumberjacks and Dubuque Fighting Saints. --- Positive: USHL completes expansion draft
Anchor: Motorola has just unveiled its gallery of conceptual handsets from the bowels of its design studios in order to celebrate 25 years of the cell phone, and we wish Motorola the best in whatever endeavors they embark upon with their future handset designs. --- Positive: Motorola celebrates 25 years of cell phones
Anchor: New Hampshire State Police say a man was found dead from a self-inflicted gunshot Wednesday afternoon following an hours-long standoff in the town of Mason. --- Positive: Man found dead following standoff
Anchor: As a former drunken sailor, I quit when I ran out of money. --- Positive: Drunken sailor runs out of money
Anchor: That's right, bath salt is the new cocaine and offers intense

In [None]:
train_dataloaderB = DataLoader(train_examplesB, shuffle=True, batch_size=64)
train_lossB = losses.MultipleNegativesRankingLoss(model=modelB)
num_epochsB = 10
warmup_stepsB = int(len(train_dataloaderB) * num_epochsB * 0.1)  # 10% of the total training steps
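Since this dataset only has (anchor, positive) pairs, the loss has to find its negatives elsewhere: every other positive in the batch acts as a negative. The sketch below shows that idea in plain Python, assuming cosine similarity and a scale of 20, which we believe are the defaults of `losses.MultipleNegativesRankingLoss`. The helpers and vectors are illustrative, not the library's implementation.

```python
import math


def cos_sim(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def mnr_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where positives[i] is the correct match for anchors[i]."""
    total = 0.0
    for i, anchor in enumerate(anchors):
        # Score the anchor against every positive in the batch.
        scores = [scale * cos_sim(anchor, p) for p in positives]
        log_norm = math.log(sum(math.exp(s) for s in scores))
        # Negative log-probability assigned to the true pair.
        total += log_norm - scores[i]
    return total / len(anchors)


anchors = [[1.0, 0.0], [0.0, 1.0]]
aligned = mnr_loss(anchors, [[1.0, 0.1], [0.1, 1.0]])  # true pairs are closest
swapped = mnr_loss(anchors, [[0.1, 1.0], [1.0, 0.1]])  # true pairs are farthest
print(aligned < swapped)
```

Because each example is compared against the rest of its batch, a larger batch supplies more negatives per anchor, which is one reason to use a larger batch size with this loss than with the triplet loss.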

In [None]:
modelB.fit(train_objectives=[(train_dataloaderB, train_lossB)],
          epochs=num_epochsB,
          warmup_steps=warmup_stepsB) 

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/24 [00:00<?, ?it/s]


In [None]:
modelB.save_to_hub(
    "distilroberta-base-sentence-transformer", 
    organization="embedding-data",
    train_datasets=["embedding-data/sentence-compression"],
    exist_ok=True, 
    )

Cloning https://huggingface.co/embedding-data/distilroberta-base-sentence-transformer into local empty directory.


Download file pytorch_model.bin:   0%|          | 3.47k/313M [00:00<?, ?B/s]

Clean file pytorch_model.bin:   0%|          | 1.00k/313M [00:00<?, ?B/s]

To https://huggingface.co/embedding-data/distilroberta-base-sentence-transformer
   8c082cd..ea459f9  main -> main



'https://huggingface.co/embedding-data/distilroberta-base-sentence-transformer/commit/ea459f9bd02c058dde7a51035cde4231af73cc25'