# Finetuning SBERT for Semantic search using MNR loss

In this notebook we will finetune a bert-base model for semantic search using Multiple Negatives Ranking (MNR) loss.

This is mostly similar to the finetuning we did in the previous notebook. The main changes are:
- We no longer put a fully connected layer on top of the embeddings
- We use Multiple Negatives Ranking loss
- We compute the cosine similarity between the pooled *u* and *v* embeddings and feed those scores into the MNR loss
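The last bullet can be sketched in plain PyTorch. This assumes mean pooling over the token embeddings (the strategy sentence-transformers falls back to when it wraps a plain Hugging Face checkpoint); the pooling used in the previous notebook is not shown here. `mean_pool` is a hypothetical helper, and the random tensors only stand in for BERT outputs.

```python
import torch
import torch.nn.functional as F


def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors over real (non-padding) positions.

    token_embeddings: (batch, seq_len, hidden), e.g. BERT's last hidden state
    attention_mask:   (batch, seq_len), 1 for real tokens and 0 for padding
    """
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(dim=1)                    # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # (batch, 1), avoids 0-division
    return summed / counts


# Toy stand-ins for BERT outputs: 2 sentence pairs, 4 tokens each, hidden size 8
torch.manual_seed(0)
u_tokens = torch.randn(2, 4, 8)
v_tokens = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])

u = mean_pool(u_tokens, mask)  # (2, 8)
v = mean_pool(v_tokens, mask)  # (2, 8)

# One cosine score per (u_i, v_i) pair; no fully connected layer is involved
pair_scores = F.cosine_similarity(u, v, dim=-1)  # (2,)
print(pair_scores)
```

The padding mask matters: without it, padded positions would be averaged into the sentence vector.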

> This loss expects as input a batch consisting of sentence pairs (a_1, p_1), (a_2, p_2)..., (a_n, p_n)
where we assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair.
>
> For each a_i, it uses all other p_j as negative samples, i.e., for a_i, we have 1 positive example (p_i) and
n-1 negative examples (p_j). It then minimizes the negative log-likelihood for softmax normalized scores.
This loss function works great to train embeddings for retrieval setups where you have positive pairs (e.g. (query, relevant_doc)) as it will sample in each batch n-1 negative docs randomly.
>
> The performance usually increases with increasing batch sizes.
>
> For more information, see: https://arxiv.org/pdf/1705.00652.pdf<br>
(Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4)
>
> You can also provide one or multiple hard negatives per anchor-positive pair by structuring the data like this:
(a_1, p_1, n_1), (a_2, p_2, n_2)
>
> Here, n_1 is a hard negative for (a_1, p_1). The loss will use for the pair (a_i, p_i) all p_j (j!=i) and all n_j as negatives.
>
> Source: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/losses/MultipleNegativesRankingLoss.py
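In symbols, for a batch of $n$ pairs with cosine similarity $s$ and a scale factor $\lambda$, the quoted loss is

$$\mathcal{L} = -\frac{1}{n}\sum_{i=1}^{n}\log\frac{\exp\big(\lambda\,s(a_i,p_i)\big)}{\sum_{j=1}^{n}\exp\big(\lambda\,s(a_i,p_j)\big)}$$

which is cross-entropy over the rows of the scaled $n \times n$ similarity matrix, with target $i$ for row $i$. Hard negatives add $\sum_{j=1}^{n}\exp\big(\lambda\,s(a_i,n_j)\big)$ to the denominator. The sketch below is a minimal PyTorch version written for illustration, not the library's code: `mnr_loss` is a hypothetical name, and `scale=20.0` is chosen to match the default of sentence-transformers' `MultipleNegativesRankingLoss`.

```python
import torch
import torch.nn.functional as F


def mnr_loss(anchors, positives, hard_negatives=None, scale=20.0):
    """Multiple Negatives Ranking loss with in-batch negatives (illustrative sketch).

    anchors, positives: (n, dim) tensors; row i of each forms a positive pair.
    hard_negatives:     optional (n, dim) tensor, one hard negative per anchor.
    scale:              multiplier on the cosine scores before the softmax.
    """
    candidates = positives
    if hard_negatives is not None:
        # Every anchor is scored against all n positives and all n hard negatives
        candidates = torch.cat([positives, hard_negatives], dim=0)  # (2n, dim)

    # scores[i, j] = scale * cos(a_i, candidate_j)
    scores = F.normalize(anchors, dim=-1) @ F.normalize(candidates, dim=-1).T
    scores = scores * scale

    # The correct candidate for anchor i is positive i, i.e. column i
    labels = torch.arange(anchors.size(0), device=anchors.device)
    return F.cross_entropy(scores, labels)


torch.manual_seed(0)
a = torch.randn(4, 16)
p_close = a + 0.1 * torch.randn(4, 16)  # positives that sit near their anchors
p_random = torch.randn(4, 16)           # positives unrelated to their anchors
print(mnr_loss(a, p_close), mnr_loss(a, p_random))
print(mnr_loss(a, p_close, hard_negatives=torch.randn(4, 16)))
```

A larger batch adds more columns to `scores`, i.e. more in-batch negatives per anchor, which is one reason the quote notes that performance usually improves with batch size.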

We will just use the (anchor, positive) pairs in our training here. We will discuss hard negatives later in the series.

Since we only need (anchor, positive) pairs, we will filter the SNLI dataset to keep the rows with label 0 (`entailment`). In this dataset, 1 is `neutral`, 2 is `contradiction`, and -1 marks pairs without a gold label.

In [1]:
import datasets

dataset = datasets.load_dataset('snli', split='train')

# SNLI labels: 0 = entailment, 1 = neutral, 2 = contradiction, -1 = no gold label.
# Keeping only label 0 leaves (premise, hypothesis) pairs where the premise entails
# the hypothesis; this also drops the unlabeled (-1) pairs
dataset = dataset.filter(lambda x: x["label"] == 0)

len(dataset), dataset[0]


(183416,
 {'premise': 'A person on a horse jumps over a broken down airplane.',
  'hypothesis': 'A person is outdoors, on a horse.',
  'label': 0})

Let's look at an example of calculating the MNR loss for a single batch.

In [13]:
import torch

batch_size = 16

anchors, positives = [], []
for i in range(batch_size):
    anchors.append(dataset[i]["premise"])
    positives.append(dataset[i]["hypothesis"])

In [14]:
anchors

['A person on a horse jumps over a broken down airplane.',
 'Children smiling and waving at camera',
 'A boy is jumping on skateboard in the middle of a red bridge.',
 'Two blond women are hugging one another.',
 'A few people in a restaurant setting, one of them is drinking orange juice.',
 'An older man is drinking orange juice at a restaurant.',
 'A man with blond-hair, and a brown shirt drinking out of a public water fountain.',
 'Two women who just had lunch hugging and saying goodbye.',
 'Two women, holding food carryout containers, hug.',
 'A Little League team tries to catch a runner sliding into a base in an afternoon game.',
 'The school is having a special event in order to show the american culture on how other cultures are dealt with in parties.',
 'High fashion ladies wait outside a tram beside a crowd of people in the city.',
 'A man, woman, and child enjoying themselves on a beach.',
 'People waiting to get on a train or just getting off.',
 'People waiting to get on a 

In [15]:
positives

['A person is outdoors, on a horse.',
 'There are children present',
 'The boy does a skateboarding trick.',
 'There are women showing affection.',
 'The diners are at a restaurant.',
 'A man is drinking juice.',
 'A blond man drinking water from a fountain.',
 'There are two woman in this picture.',
 'Two women hug each other.',
 'A team is trying to tag a runner out.',
 'A school is hosting an event.',
 'Women are waiting by a tram.',
 'A family of three is at the beach.',
 'There are people just getting on a train',
 'There are people waiting on a train.',
 'A couple are playing with a young child outside.']
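Every other positive in the batch is treated as a negative for anchor *i*, so if the same premise appears twice in one batch, a valid match is scored as a negative (a false negative). Because SNLI pairs each premise with several hypotheses, a batch of consecutive rows can contain the same premise more than once. A minimal sketch of one workaround, assuming simple greedy grouping: defer any pair whose anchor is already in the batch being filled. `batches_without_duplicate_anchors` is a hypothetical helper; the 2.x releases of sentence-transformers ship a `NoDuplicatesDataLoader` with a similar goal.

```python
def batches_without_duplicate_anchors(pairs, batch_size):
    """Greedily split (anchor, positive) pairs into batches with no repeated anchor.

    A pair whose anchor already appears in the batch being filled is deferred to a
    later batch rather than dropped, so every pair is used exactly once.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    pending = list(pairs)
    batches = []
    while pending:
        batch, seen, deferred = [], set(), []
        for anchor, positive in pending:
            if len(batch) < batch_size and anchor not in seen:
                batch.append((anchor, positive))
                seen.add(anchor)
            else:
                deferred.append((anchor, positive))
        batches.append(batch)
        pending = deferred
    return batches


# Toy pairs (made up for illustration): "premise 1" occurs twice
pairs = [("premise 1", "hypothesis a"), ("premise 1", "hypothesis b"),
         ("premise 2", "hypothesis c"), ("premise 3", "hypothesis d")]
for batch in batches_without_duplicate_anchors(pairs, batch_size=2):
    print(batch)
```

The first pair of each pass always fits, so the loop is guaranteed to finish; the trade-off is that deferred pairs can end up in a smaller final batch.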

In [None]:
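from sentence_transformers import SentenceTransformer, util

# "bert-base-uncased" is an assumption based on the intro's "bert-base"; loading a plain
# Hugging Face checkpoint makes sentence-transformers add a mean-pooling layer
model = SentenceTransformer("bert-base-uncased")
loss_fn = torch.nn.CrossEntropyLoss()

# encode() runs without gradients, so this only evaluates the loss (no training step)
anchor_embeddings = model.encode(anchors, convert_to_tensor=True)
positive_embeddings = model.encode(positives, convert_to_tensor=True)

# similarity_matrix[i, j] = cos(anchor_i, positive_j); the target "class" for row i is column i
similarity_matrix = util.cos_sim(anchor_embeddings, positive_embeddings)
target = torch.arange(len(anchor_embeddings), device=similarity_matrix.device)

loss_fn(similarity_matrix, target)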
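Two properties of this computation can be checked on toy tensors (random stand-ins, not model outputs). First, cross-entropy with targets `0..n-1` equals the mean over anchors of $-\log\operatorname{softmax}$ of the diagonal entry, which is the MNR objective. Second, raw cosine scores lie in $[-1, 1]$, so within a row the logits differ by at most 2 and each anchor's loss can never drop below $\log\big(1 + (n-1)e^{-2}\big)$, roughly 1.11 for $n = 16$, even when the embeddings are perfect. Multiplying the scores by a scale (sentence-transformers' `MultipleNegativesRankingLoss` uses `scale`, default 20) lets the loss approach zero.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, dim = 16, 32
anchor_embeddings = torch.randn(n, dim)
positive_embeddings = anchor_embeddings.clone()  # the "perfect" case: each positive equals its anchor

# Same construction as the notebook cell: cosine similarity matrix with diagonal targets
sim = F.normalize(anchor_embeddings, dim=-1) @ F.normalize(positive_embeddings, dim=-1).T
target = torch.arange(n)

ce = F.cross_entropy(sim, target)
# Manual MNR: average -log softmax probability assigned to each anchor's own positive
manual = -sim.log_softmax(dim=1).diagonal().mean()

floor = math.log(1 + (n - 1) * math.exp(-2))  # lower bound for unscaled cosine scores
scaled = F.cross_entropy(sim * 20.0, target)   # same scores multiplied by scale = 20

print(ce.item(), manual.item(), floor, scaled.item())
```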