# Fine tune text embeddings
[basic](https://huggingface.co/blog/how-to-train-sentence-transformers)

# Imports

In [9]:
!pip install sentence-transformers datasets -qqq

In [6]:
import os
import pandas as pd
from sentence_transformers import InputExample, SentenceTransformer
from sentence_transformers import util

from datasets import load_dataset
from torch.utils.data import DataLoader

from sentence_transformers import losses

os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

# Datasets

- In all cases, negatives are implicitly created, whether we provide them or not. The negatives are created by the model itself, and are the other documents in the batch. This is why we don't need to provide negatives in the dataset. When we provide the explicit negatives, the model will use them instead of the implicit negatives.

| dataset_structure           | examples                                                                          | loss                                                                                                | application                                                           |
|-----------------------------|-----------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------|
| <query, document, label (grade)>    | snli                                                                              | ContrastiveLoss; SoftmaxLoss; CosineSimilarityLoss                                                  | natural language inference (NLI) {entailment (positive), neutral, contradiciton (negative)} |
| <query, document (positive)>           | embedding-data/flickr30k_captions_quintets; embedding-data/coco_captions_quintets; embedding-data/sentence-compression | MultipleNegativesRankingLoss; MegaBatchMarginLoss                                                   | natural language inference (NLI) {entailment}                         |
| <query, class>              | trec; yahoo_answers_topics                                                        | BatchHardTripletLoss; BatchAllTripletLoss; BatchHardSoftMarginTripletLoss; BatchSemiHardTripletLoss |                                                                       |
| <query, document (positive), document (negative)>| embedding-data/QQP_triplets                                                       | TripletLoss;                                                                                        |                                                                       |

# Preprocessing: Dataset Type: Case 2: <query, document>

In [143]:
dataset_id = "embedding-data/sentence-compression"
dataset = load_dataset(dataset_id)


sample = dataset["train"][-1]

print(f"The {dataset_id} dataset has {dataset['train'].num_rows} examples.")
print(f"Each example is a {type(sample)} with a {type(sample['set'])} as value.")

print(f"Examples look like this: {sample}")

The embedding-data/sentence-compression dataset has 180000 examples.
Each example is a <class 'dict'> with a <class 'list'> as value.
Examples look like this: {'set': ['Two of the most annoying forms of musical expression might all too soon converge to the sound of shrieking, sophomoric orchestral crescendos and controversy.', 'Two most annoying forms of musical expression converge...']}


In [160]:
import re


def contains_topic(dataset, topic="sport|football|soccer"):
    return any(
        re.search(rf"\b({topic})\b", text, re.IGNORECASE) for text in dataset["set"]
    )


topic_dataset = dataset["train"].filter(contains_topic)
print(f"The dataset has {topic_dataset.num_rows} examples.")
print(f"Examples look like this: {topic_dataset[-1]}")

The dataset has 2045 examples.
Examples look like this: {'set': ['ACC Commissioner John Swofford shakes the hand of Notre Dame president Rev. John I. Jenkins after Notre Dame announced it would join the ACC. The Fighting Irish will maintain an independent football team.', 'Notre Dame joins the ACC']}


In [161]:
train_examples = []
train_data = topic_dataset["set"]
n_examples = topic_dataset.num_rows
max_samples = 10_000

for i in range(n_examples)[:max_samples]:
    example = train_data[i]
    train_examples.append(InputExample(texts=[example[0], example[1]]))

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=64)

# Model

In [63]:
model_id = "embedding-data/distilroberta-base-sentence-transformer"
raw_model = SentenceTransformer(model_id)
positives_model = SentenceTransformer(model_id)

assert positives_model is not raw_model

# Train

In [64]:
loss = losses.MultipleNegativesRankingLoss(model=positives_model)
num_epochs = 10
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

positives_model.fit(
    train_objectives=[(train_dataloader, loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
)

Iteration: 100%|██████████| 32/32 [02:12<00:00,  4.15s/it]
Iteration: 100%|██████████| 32/32 [01:13<00:00,  2.30s/it]
Iteration: 100%|██████████| 32/32 [01:51<00:00,  3.49s/it]
Iteration: 100%|██████████| 32/32 [01:48<00:00,  3.39s/it]
Iteration: 100%|██████████| 32/32 [01:53<00:00,  3.54s/it]
Iteration: 100%|██████████| 32/32 [02:31<00:00,  4.72s/it]
Iteration: 100%|██████████| 32/32 [02:33<00:00,  4.80s/it]
Iteration: 100%|██████████| 32/32 [02:07<00:00,  3.98s/it]
Iteration: 100%|██████████| 32/32 [01:44<00:00,  3.27s/it]
Iteration: 100%|██████████| 32/32 [01:46<00:00,  3.33s/it]
Epoch: 100%|██████████| 10/10 [19:43<00:00, 118.31s/it]


# Evaluate

In [79]:
def predict(model):
    query = "I love playing soccer by the beach, it's fun! I play the midfield position, and love passing the ball. My favorite team is Liverpool and try to play like them."
    documents = [
        "I love Liverpool!",
        "Midfielders love passing the ball",
        "I love playing volleyball while I'm at the beach",
        "I enjoy watching sports",
        "I enjoy cooking dinner for my family",
        "I enjoy dancing, and live music on a night out at the bar",
        "I love dogs!",
    ]

    query_embedding = model.encode([query])
    document_embeddings = model.encode(documents)

    similarities = util.cos_sim(query_embedding, document_embeddings)

    print(f"Query: {query}\n")
    results = sorted(
        zip(documents, similarities[0].tolist()), key=lambda x: x[1], reverse=True
    )
    return pd.DataFrame(results, columns=["document", "similarity"])

In [176]:
print("Embedding similarity from the raw model:")
display(predict(raw_model))

print("Embedding similarity from the fine-tuned model:")
display(predict(positives_model))

Embedding similarity from the raw model:
Query: I love playing soccer by the beach, it's fun! I play the midfield position, and love passing the ball. My favorite team is Liverpool and try to play like them.



Unnamed: 0,document,similarity
0,I love playing volleyball while I'm at the beach,0.850552
1,Midfielders love passing the ball,0.795377
2,I love Liverpool!,0.795222
3,"I enjoy dancing, and live music on a night out...",0.793633
4,I enjoy watching sports,0.756239
5,I love dogs!,0.752807
6,I enjoy cooking dinner for my family,0.717749


Embedding similarity from the fine-tuned model:
Query: I love playing soccer by the beach, it's fun! I play the midfield position, and love passing the ball. My favorite team is Liverpool and try to play like them.



Unnamed: 0,document,similarity
0,I love playing volleyball while I'm at the beach,0.695838
1,I love Liverpool!,0.622702
2,Midfielders love passing the ball,0.571108
3,I enjoy watching sports,0.494715
4,"I enjoy dancing, and live music on a night out...",0.442614
5,I love dogs!,0.408467
6,I enjoy cooking dinner for my family,0.267176


# Preprocessing: Dataset Type: Case 4: <query, document (positive), document (negative)>

In [145]:
dataset_id = "embedding-data/QQP_triplets"
triplet_dataset = load_dataset(dataset_id)

sample = triplet_dataset["train"][-1]

print(f"The {dataset_id} dataset has {triplet_dataset['train'].num_rows} examples.")
print(f"Each example is a {type(sample)} with a {type(sample['set'])} as value.")

print(f"Examples look like this: {sample}")
print(f"Positives: {len(sample.get('set').get('pos'))}")
print(f"Negatives: {len(sample.get('set').get('neg'))}")

The embedding-data/QQP_triplets dataset has 101762 examples.
Each example is a <class 'dict'> with a <class 'dict'> as value.
Examples look like this: {'set': {'query': 'Why do you use an iPhone?', 'pos': ['Why do people buy the iPhone?'], 'neg': ["Why shouldn't I buy an iPhone?", 'Why is iPhone so expensive?', 'Why are iPhones so expensive?', 'Why iphone are so costly?', 'Why are iPhones costly?', 'Is the iPhone really more expensive? Why or why not?', 'Why people are madly buying iPhone 4 in India, given that it is a more than 3-year-old hardware?', 'Why should I not buy the iPhone 5?', 'Why should I not buy an iPhone 7?', 'Why do some people prefer iPhones to Androids?', 'What are the reasons why people buy Samsung phones?', 'Why are iPhone users so loyal to the brand?', 'Why is the iPhone 6 so expensive?', 'Are iPhones seriously worth the price?', 'Are Apple iPhones worth the price?', 'Why is the iPhone 6s so expensive?', 'Is the iPhone really worth its price?', 'Is iPhone really w

In [156]:
def does_not_contain_topic(dataset, topic="football|soccer"):
    return all(
        not re.search(rf"\b({topic})\b", text, re.IGNORECASE) for text in dataset["set"]
    )


non_topic_dataset = dataset["train"].filter(does_not_contain_topic)
print(f"The dataset has {non_topic_dataset.num_rows} examples.")
print(f"Examples look like this: {non_topic_dataset[-1]}")

Filter:   0%|          | 0/180000 [00:00<?, ? examples/s]

Filter: 100%|██████████| 180000/180000 [00:01<00:00, 111526.65 examples/s]

The dataset has 178225 examples.
Examples look like this: {'set': ['Two of the most annoying forms of musical expression might all too soon converge to the sound of shrieking, sophomoric orchestral crescendos and controversy.', 'Two most annoying forms of musical expression converge...']}





In [162]:
from datasets import Dataset

triplets = []

for p, n in zip(topic_dataset["set"], non_topic_dataset["set"]):
    query = p[0]
    pos = p[-1]
    neg = n[-1]
    triplets.append({"query": query, "pos": pos, "neg": [neg]})

triplet_topic_dataset = Dataset.from_pandas(pd.DataFrame(data=triplets))
print(f"The dataset has {triplet_topic_dataset.num_rows} examples.")
print(f"Examples look like this: {triplet_topic_dataset[-1]}")

The dataset has 2045 examples.
Examples look like this: {'query': 'ACC Commissioner John Swofford shakes the hand of Notre Dame president Rev. John I. Jenkins after Notre Dame announced it would join the ACC. The Fighting Irish will maintain an independent football team.', 'pos': 'Notre Dame joins the ACC', 'neg': ['France Telecom withdraws TeliaSonera offer']}


In [170]:
train_examples = []
train_data = triplet_topic_dataset
n_examples = triplet_topic_dataset.num_rows
max_samples = 10_000

for i in range(n_examples)[:max_samples]:
    example = train_data[i]
    train_examples.append(
        InputExample(texts=[example["query"], example["pos"][0], example["neg"][0]])
    )

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=64)

# Model

In [171]:
model_id = "embedding-data/distilroberta-base-sentence-transformer"
triplet_model = SentenceTransformer(model_id)

assert model is not raw_model

# Train

In [172]:
loss = losses.TripletLoss(model=triplet_model)
num_epochs = 10
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

triplet_model.fit(
    train_objectives=[(train_dataloader, loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
)

Iteration: 100%|██████████| 32/32 [02:50<00:00,  5.33s/it]
Iteration: 100%|██████████| 32/32 [01:39<00:00,  3.11s/it]
Iteration: 100%|██████████| 32/32 [01:48<00:00,  3.39s/it]
Iteration: 100%|██████████| 32/32 [01:48<00:00,  3.40s/it]
Iteration: 100%|██████████| 32/32 [01:43<00:00,  3.23s/it]
Iteration: 100%|██████████| 32/32 [01:54<00:00,  3.58s/it]
Iteration: 100%|██████████| 32/32 [01:35<00:00,  2.99s/it]
Iteration: 100%|██████████| 32/32 [01:47<00:00,  3.37s/it]
Iteration: 100%|██████████| 32/32 [01:35<00:00,  2.98s/it]
Iteration: 100%|██████████| 32/32 [01:43<00:00,  3.24s/it]
Epoch: 100%|██████████| 10/10 [18:27<00:00, 110.78s/it]


# Evaluate

In [175]:
print("Embedding similarity from the raw model:")
display(predict(raw_model))

print("Embedding similarity from the fine-tuned model (positives):")
display(predict(positives_model))

print("Embedding similarity from the fine-tuned model (triplets):")
display(predict(triplet_model))

Embedding similarity from the raw model:
Query: I love playing soccer by the beach, it's fun! I play the midfield position, and love passing the ball. My favorite team is Liverpool and try to play like them.



Unnamed: 0,document,similarity
0,I love playing volleyball while I'm at the beach,0.850552
1,Midfielders love passing the ball,0.795377
2,I love Liverpool!,0.795222
3,"I enjoy dancing, and live music on a night out...",0.793633
4,I enjoy watching sports,0.756239
5,I love dogs!,0.752807
6,I enjoy cooking dinner for my family,0.717749


Embedding similarity from the fine-tuned model (positives):
Query: I love playing soccer by the beach, it's fun! I play the midfield position, and love passing the ball. My favorite team is Liverpool and try to play like them.



Unnamed: 0,document,similarity
0,I love playing volleyball while I'm at the beach,0.695838
1,I love Liverpool!,0.622702
2,Midfielders love passing the ball,0.571108
3,I enjoy watching sports,0.494715
4,"I enjoy dancing, and live music on a night out...",0.442614
5,I love dogs!,0.408467
6,I enjoy cooking dinner for my family,0.267176


Embedding similarity from the fine-tuned model (triplets):
Query: I love playing soccer by the beach, it's fun! I play the midfield position, and love passing the ball. My favorite team is Liverpool and try to play like them.



Unnamed: 0,document,similarity
0,Midfielders love passing the ball,0.388242
1,I enjoy watching sports,0.3425
2,"I enjoy dancing, and live music on a night out...",0.288263
3,I love playing volleyball while I'm at the beach,0.198876
4,I love dogs!,0.130319
5,I love Liverpool!,0.1116
6,I enjoy cooking dinner for my family,-0.054029
