# Contrastive Learning Notes

SBERT (Sentence-BERT) [[5]](https://arxiv.org/abs/1908.10084) demonstrated that additional fine-tuning on top of pre-trained transformer-based language models is beneficial for competitive sentence embedding performance. From a practical standpoint, sentence embeddings are particularly useful in:

- Retrieval tasks, where the typical setup is a bi-encoder, a.k.a. twin-tower architecture. This architecture encodes each entity independently, so its embedding can later be compared against the embeddings of future inputs to retrieve "similar" entities (the definition of similar is use case dependent). Because of this independence, embeddings can be pre-computed and cached, and "similar" entities can be retrieved through fast approximate nearest neighbor lookups, which is critical for latency-sensitive applications.
- Embedding-based classification tasks, where the embeddings are fed as features to downstream models. Unlike a typical fine-tuning setup, the embeddings, once generated, are treated as frozen and are not tuned along with the downstream system. This is common when the downstream application relies on non-deep-learning models, such as gradient boosted trees, as its machine learning algorithm of choice.
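
A minimal sketch of the pre-compute-then-look-up pattern from the retrieval bullet, assuming the embeddings already come from some trained encoder (the 2-d vectors below are toy placeholders). The corpus is normalized and cached once offline; at request time only the query needs to be embedded. Brute-force exact search stands in for an approximate nearest neighbor index, and `build_index` / `search` are hypothetical helper names, not a library API.

```python
import numpy as np

def l2_normalize(x):
    # Scale each row to unit length so that a dot product equals cosine similarity
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def build_index(corpus_embeddings):
    # Offline: embed the corpus once, then normalize and cache the vectors
    return l2_normalize(np.asarray(corpus_embeddings, dtype=np.float64))

def search(index, query_embeddings, k):
    # Online: only the queries need to be embedded at request time
    queries = l2_normalize(np.asarray(query_embeddings, dtype=np.float64))
    scores = queries @ index.T                       # [num_queries, corpus_size]
    top_k = np.argsort(-scores, axis=1)[:, :k]       # exact search standing in for ANN
    return top_k, np.take_along_axis(scores, top_k, axis=1)

# Toy 2-d "embeddings" standing in for encoder outputs
index = build_index([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
ids, sims = search(index, [[0.9, 0.1]], k=2)
print(ids)  # [[0 2]]
```

In a production system the exact `argsort` would be swapped for an approximate nearest neighbor index; the offline/online split stays the same.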

Example industrial use case: Facebook search's embedding based retrieval [[15]](https://arxiv.org/abs/2006.11632).

## Key Recipes

As with most other use cases, we can of course use a more powerful encoder to generate the embedding representations for our anchors and positives. Beyond that, some tips specific to improving the performance of contrastive-loss-based learning are:

**Noise Contrastive Objective Function**

In recent years, most of these fine-tuning procedures leverage contrastive learning via variants of the InfoNCE loss [[11]](https://arxiv.org/abs/1807.03748), whose name derives from Noise Contrastive Estimation. This type of loss is also referred to as NT-Xent (normalized temperature-scaled cross entropy loss) in SimCLR (a simple framework for contrastive learning of visual representations) [[12]](https://arxiv.org/abs/2002.05709), or as multiple negatives ranking loss [[2]](https://www.sbert.net/docs/package_reference/losses.html#multiplenegativesrankingloss). It uses a cross entropy loss to distinguish positive pairs from negative pairs.

\begin{align}
L = -\frac{1}{n} \sum^n_{i=1} \log \frac{\exp(\text{sim}(a_i, p_i) / \tau)}{\sum^n_{j=1} \exp(\text{sim}(a_i, p_j) / \tau)}
\end{align}

To train with this loss, we need pairs of anchors and their corresponding positive examples ($a_i$, $p_i$). The goal is for $a_i$ and $p_i$ to become close in the vector space, while $a_i$ and the other examples $p_j$ ($j \neq i$), which act as negatives, become distant. $\text{sim}$ denotes a similarity function such as cosine similarity or dot product, and $\tau$ denotes the temperature, which can be either a learned parameter or a configurable value. Note that when using cosine similarity as the similarity function, we typically need to set a small temperature: cosine similarity is bounded in $[-1, 1]$, so without scaling, the differences between scores are too small for the softmax to separate positives from negatives, which does not lead to good empirical results.
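
To make the temperature argument concrete, the sketch below takes hypothetical cosine similarities between one anchor and its positive (0.7) and three negatives, and computes the softmax probability assigned to the positive. At $\tau = 1$ the distribution stays close to uniform; at $\tau = 0.05$ (equivalent to multiplying the similarities by 20, which we believe matches the default `scale` of sentence-transformers' `MultipleNegativesRankingLoss`) the positive clearly dominates. The helper name and the similarity values are illustrative only.

```python
import numpy as np

def positive_prob(cos_sims, tau):
    # Softmax over [positive, negatives...] after temperature scaling; index 0 is the positive
    logits = np.asarray(cos_sims) / tau
    exp = np.exp(logits - logits.max())   # subtract the max for numerical stability
    return exp[0] / exp.sum()

# Hypothetical cosine similarities of one anchor to its positive and three negatives
cos_sims = [0.7, 0.5, 0.4, 0.3]
for tau in (1.0, 0.05):
    print(f"tau={tau}: p(positive)={positive_prob(cos_sims, tau):.3f}")
# tau=1.0: p(positive)=0.310
# tau=0.05: p(positive)=0.979
```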

PyTorch-style code for this loss is shown below:

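```python
import torch
import torch.nn.functional as F

def info_nce_loss(embedding_a, embedding_p, temperature):
    # @ denotes matrix multiplication, and the similarity metric here is a dot product
    scores = embedding_a @ embedding_p.T / temperature  # [batch size, batch size], positives on the diagonal
    labels = torch.arange(len(scores), device=scores.device, dtype=torch.long)
    return F.cross_entropy(scores, labels)
```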

`embedding_a` and `embedding_p` are the embedding representations of the anchors and positives, with shape `[batch size, hidden size]`. They are generated by a backbone encoder (e.g. a transformer), optionally followed by a linear projection of our choice. The anchor and positive encoders can even share weights, which is commonly referred to as a siamese network.
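
For readers who want to check the math without PyTorch, here is a NumPy sketch of the same loss written directly from the equation, with two assumptions of ours: an optional L2 normalization (turning the dot product into cosine similarity) and a default temperature of 0.05. The toy usage applies one random projection matrix to both anchors and positives, mimicking a siamese network whose two towers share weights; the matrices are random placeholders, not a trained encoder.

```python
import numpy as np

def info_nce(embedding_a, embedding_p, temperature=0.05, cosine=True):
    if cosine:
        # L2-normalize so that the dot product below is cosine similarity
        embedding_a = embedding_a / np.linalg.norm(embedding_a, axis=1, keepdims=True)
        embedding_p = embedding_p / np.linalg.norm(embedding_p, axis=1, keepdims=True)
    scores = embedding_a @ embedding_p.T / temperature   # [B, B], positives on the diagonal
    # Row-wise log-softmax via the log-sum-exp trick
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))

# Siamese setup: one (toy, random) projection matrix shared by the anchor and positive towers
rng = np.random.default_rng(0)
shared_w = rng.normal(size=(8, 4))
x_a = rng.normal(size=(16, 8))
x_p = x_a + 0.1 * rng.normal(size=(16, 8))   # positives are noisy copies of the anchors
print(info_nce(x_a @ shared_w, x_p @ shared_w))
```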

**Larger In-Batch Negatives**

In-batch negatives are widely used for training models with contrastive loss. Assuming there are $B$ positive pairs in a given mini batch, each anchor can be paired with $B - 1$ negatives (the other pairs' positive passages). This paradigm lets us leverage the already loaded mini batch rather than spending additional resources on sampling negative examples. When relying on in-batch negative sampling, using a large batch size is key, as a larger batch allows the loss function to optimize over a more diverse set of negative samples, i.e. it's easier to pick the right answer out of a pool of 2 candidates than out of, say, 1024 candidates. This can be treated as an implicit version of hard negative mining.
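
One way to see why batch size matters: if the model cannot yet tell candidates apart and assigns every candidate the same score, the in-batch loss equals $\log B$, so the bar the model has to clear rises with the number of in-batch negatives. A small NumPy check (the helper name is ours):

```python
import numpy as np

def in_batch_loss(scores):
    # scores[i, j] = sim(a_i, p_j); row i has 1 positive (j == i) and B - 1 in-batch negatives
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))

for batch_size in (2, 32, 1024):
    # An uninformative model scores every candidate equally
    uniform_scores = np.zeros((batch_size, batch_size))
    print(batch_size, round(in_batch_loss(uniform_scores), 3))
# 2 0.693
# 32 3.466
# 1024 6.931
```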

Works such as RocketQA [[6]](https://arxiv.org/abs/2010.08191) and CLIP (Contrastive Language-Image Pre-Training) [[8]](https://arxiv.org/abs/2103.00020) further use cross-GPU negatives: when training on multiple GPUs, each GPU computes the passage embeddings for its own shard of the batch, and these embeddings are then shared among all GPUs to serve as additional negative examples, i.e. with $A$ GPUs, each anchor now has $A \times B - 1$ negatives. Note that sharing here refers to a differentiable all-gather operation. With this approach, CLIP reports using an effective batch size as large as 32,768, sharded across 256 GPUs. The optimal batch size will depend on the training data size; for reference, CLIP's training data consists of 400M image-text pairs.
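
A single-process NumPy simulation of the cross-GPU bookkeeping, where a list of shards stands in for $A$ devices and `np.concatenate` stands in for the differentiable all-gather (a real implementation would rely on a distributed communication library and is not shown). The key detail is the label offset: after gathering, the device with index `rank` finds its own positives starting at row `rank * B`. The function name and shapes are illustrative assumptions.

```python
import numpy as np

def cross_device_logits_and_labels(local_a, gathered_p, rank, temperature):
    # local_a: [B, d] anchors on this device; gathered_p: [A * B, d] positives from all devices
    batch_size = local_a.shape[0]
    scores = local_a @ gathered_p.T / temperature        # [B, A * B]
    # This device's positives sit at offset rank * B in the gathered tensor
    labels = rank * batch_size + np.arange(batch_size)
    return scores, labels

num_devices, batch_size, dim = 4, 8, 16
rng = np.random.default_rng(0)
shards_p = [rng.normal(size=(batch_size, dim)) for _ in range(num_devices)]
gathered_p = np.concatenate(shards_p, axis=0)   # stands in for a differentiable all-gather
local_a = rng.normal(size=(batch_size, dim))
scores, labels = cross_device_logits_and_labels(local_a, gathered_p, rank=2, temperature=0.05)
print(scores.shape, scores.shape[1] - 1)  # (8, 32) 31 -> A * B - 1 negatives per anchor
```

The resulting `scores` and `labels` can be fed to the same cross entropy loss as the single-device case.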

**Data Augmentation with Denoised Negatives**

While increasing the batch size for in-batch negative sampling increases the number of negative samples, many of them might be easy negatives that our model can quickly discern, so they potentially provide little additional information. Hence, we need a mechanism to find hard negatives. In other words, instead of providing pairs of anchors $a_i$ and positives $p_i$ as our input data, we now provide triplets ($a_i$, $p_i$, $n_i$), where the negative $n_i$ should be similar to $p_i$ but not match $a_i$. The primary strategies are:

- Leveraging structure from our data [[1]](https://www.youtube.com/watch?v=RHXZKUr8qOY): this strategy relies on our creativity and domain knowledge. For example:
    - For a StackExchange question-answering dataset whose sub-forums cover topics such as programming, travel, and cooking, building each batch from a single sub-forum will likely yield higher quality batches, since in-batch negatives on the same topic are harder to tell apart.
    - For a Stack Overflow question-answering dataset, we can take answers with many upvotes as the positive samples and answers without any upvotes as hard negatives.
    - For a paper citation dataset, given a seed paper as our anchor, we can use a paper cited by the seed paper as the positive, and papers cited by that cited paper, but not by the seed paper, as negatives.
    - For website logs, we can derive positives and negatives from user engagement signals such as impressions, clicks, or purchases [[15]](https://arxiv.org/abs/2006.11632).
- Algorithmically generate them:
    - RocketQA [[6]](https://arxiv.org/abs/2010.08191), Augmented SBERT [[7]](https://arxiv.org/abs/2010.08240), and DPR (Dense Passage Retrieval) [[9]](https://arxiv.org/abs/2004.04906) suggest mining hard negatives with BM25 or a trained bi-encoder to retrieve semantically similar candidates. This is likely to perform better than lexical edits such as insert/swap/delete/synonym replacement [[4]](https://github.com/makcedward/nlpaug).
    - For images, there is a plethora of augmentation techniques, e.g. random cropping, resizing, color distortion, Gaussian blur, etc., all of which transform pixels while preserving the semantic meaning of an image's content, such as its class label. In unsupervised settings such as SimCLR [[12]](https://arxiv.org/abs/2002.05709), which relies on data augmentations to create positive pairs, the authors argue that contrastive learning needs stronger data augmentation than its supervised counterpart to learn strong representations, with the composition of random cropping and color distortion standing out.
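
A sketch of dense hard negative mining in the spirit of the bi-encoder approach above, assuming the query and corpus embeddings were already produced by a trained bi-encoder (BM25 scores could be substituted for the dot products). For each query it walks the ranking from the top, skips labeled positives, and keeps the first few remaining documents. The `keep` argument is a hypothetical hook for the false-negative filtering discussed next; the helper name and toy data are ours.

```python
import numpy as np

def mine_hard_negatives(query_emb, corpus_emb, positive_ids, num_negatives=2, keep=None):
    """For each query, return the highest-scoring corpus ids that are not labeled positives.

    keep(query_idx, doc_idx) is an optional, hypothetical filter, e.g. a cross-encoder check
    that rejects candidates likely to be unlabeled positives.
    """
    scores = query_emb @ corpus_emb.T           # [num_queries, corpus_size]
    ranked = np.argsort(-scores, axis=1)        # best match first
    hard_negatives = []
    for q, ranking in enumerate(ranked):
        candidates = [int(d) for d in ranking
                      if int(d) not in positive_ids[q] and (keep is None or keep(q, int(d)))]
        hard_negatives.append(candidates[:num_negatives])
    return hard_negatives

query_emb = np.array([[1.0, 0.0]])
corpus_emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.8, 0.2]])
print(mine_hard_negatives(query_emb, corpus_emb, positive_ids=[{0}]))  # [[1, 3]]
```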

With this strategy, we need to be mindful and ensure these examples are actually negatives. It is typically infeasible to scan through our entire dataset and label all the positive examples for a given anchor, so when sampling hard negatives, we might accidentally sample an unlabeled positive example, introducing false negatives into the mix. To mitigate this, we can train a separate cross-encoder model, which is typically more powerful at capturing semantic similarity than a bi-encoder, to denoise our hard negatives. In other words, when sampling hard negatives from the top-ranked examples using the aforementioned strategies, we only keep the ones that the cross-encoder predicts as negatives with a high confidence score.
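
Once hard negatives are available, one common way to consume triplets ($a_i$, $p_i$, $n_i$), similar in spirit to how the SBERT documentation describes passing hard negatives to multiple negatives ranking loss, is to append the $B$ hard negatives to the $B$ in-batch positives. Each anchor is then scored against $2B$ candidates and sees $2B - 1$ negatives: its own hard negative plus every other pair's positive and hard negative, which mixes random and hard negatives in one loss. A NumPy sketch, assuming dot-product similarity and a default temperature of 0.05 (the function name is hypothetical):

```python
import numpy as np

def triplet_info_nce(emb_a, emb_p, emb_n, temperature=0.05):
    # Candidates: all B positives followed by all B hard negatives -> [2B, d]
    candidates = np.concatenate([emb_p, emb_n], axis=0)
    scores = emb_a @ candidates.T / temperature          # [B, 2B]
    # Row-wise log-softmax; anchor i's positive is still column i
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    rows = np.arange(emb_a.shape[0])
    return -np.mean(log_probs[rows, rows])
```

Passing `emb_n = emb_p` is a handy sanity check: every positive then has an exact duplicate among the candidates, so the loss rises by exactly $\log 2$ over the pair-only loss.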

Training with a mix of random and hard negatives is often beneficial, though the optimal proportion of the two is something we'll have to experiment with for our use case. The overall data augmentation workflow can be roughly summarized into the following steps [[6]](https://arxiv.org/abs/2010.08191) [[7]](https://arxiv.org/abs/2010.08240):

- First, train both a bi-encoder and a cross-encoder model.
- Select additional input pairs and label them with the cross-encoder, i.e. generate soft labels. Selecting suitable pairs is crucial for this augmentation strategy's success; simply combining random pairs may lead to suboptimal downstream performance.
- These additional pairs are added to the training set.
- Fine-tune a new bi-encoder model on this larger augmented training dataset.
- Rinse and repeat.
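
The loop above can be sketched in plain Python with every model-specific piece injected as a callable. `train_bi_encoder`, `train_cross_encoder`, and `select_pairs` are hypothetical placeholders rather than a real library API, and the sketch assumes, as a simplification, that the cross-encoder is trained once on the gold data and reused across rounds.

```python
def augment_and_retrain(gold, select_pairs, train_bi_encoder, train_cross_encoder, num_rounds=2):
    """Train -> soft-label -> augment -> retrain loop; every callable argument is a hypothetical stand-in.

    gold: list of (text_a, text_b, label) tuples.
    select_pairs(bi_encoder) -> list of (text_a, text_b) pairs to label (e.g. mined with the bi-encoder).
    train_cross_encoder(data) must return a scorer: scorer(text_a, text_b) -> soft label in [0, 1].
    """
    # Step 1: train a cross-encoder and a bi-encoder on the gold data
    cross_encoder = train_cross_encoder(gold)
    bi_encoder = train_bi_encoder(gold)
    training_set = list(gold)
    for _ in range(num_rounds):  # Step 5: rinse and repeat
        # Step 2: select new pairs and let the cross-encoder produce soft labels
        silver = [(a, b, cross_encoder(a, b)) for a, b in select_pairs(bi_encoder)]
        # Step 3: add the soft-labeled pairs to the training set
        training_set.extend(silver)
        # Step 4: fine-tune a new bi-encoder on the augmented data
        bi_encoder = train_bi_encoder(training_set)
    return bi_encoder, training_set

# Toy stubs: the "bi-encoder" is just the training-set size, the "cross-encoder" a constant scorer
bi, data = augment_and_retrain(
    gold=[("q0", "d0", 1.0)],
    select_pairs=lambda bi_encoder: [(f"q{bi_encoder}", f"d{bi_encoder}")],
    train_bi_encoder=lambda d: len(d),
    train_cross_encoder=lambda d: (lambda a, b: 0.5),
)
print(bi, data[-1])  # 3 ('q2', 'd2', 0.5)
```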

Note that pure image works such as MoCo (Momentum Contrast) [[14]](https://arxiv.org/abs/2104.02057) include the concept of a momentum encoder. It is not elaborated upon here as it targets an unsupervised image setting; in the text setting, works such as E5 (EmbEddings from bidirEctional Encoder rEpresentations) [[13]](https://arxiv.org/abs/2212.03533) claim that training with a larger batch size is more stable and results in no performance difference.

## Public Dataset & Benchmark

The model card of a public pre-trained sentence-transformers model [[3]](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) and the E5 paper [[13]](https://arxiv.org/abs/2212.03533) list public datasets that can be combined and used to fine-tune these models.

As with all methods, it is important to find an established benchmark so we can quickly iterate on new ideas. MTEB (Massive Text Embedding Benchmark) [[10]](https://arxiv.org/abs/2210.07316) collects 8 embedding tasks, ranging from semantic textual similarity (STS, SemEval) and classification (training a classifier that uses the embeddings as input features, SentEval) to information retrieval (BEIR), etc. In total it consists of 56 datasets covering 112 languages. The authors also evaluated 30 different models to provide a holistic view of state-of-the-art public pre-trained text embedding models.

## References

- [[1]](https://www.youtube.com/watch?v=RHXZKUr8qOY) YouTube: Training State of the Art Sentence Embedding Models
- [[2]](https://www.sbert.net/docs/package_reference/losses.html#multiplenegativesrankingloss) SBERT Documentation - Multiple Negatives Ranking Loss
- [[3]](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) Model Card: sentence-transformers/all-MiniLM-L12-v2
- [[4]](https://github.com/makcedward/nlpaug) GitHub: Data augmentation for NLP
- [[5]](https://arxiv.org/abs/1908.10084) Nils Reimers, Iryna Gurevych - Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks - 2019
- [[6]](https://arxiv.org/abs/2010.08191) Kai Liu, Ruiyang Ren, Wayne Xin Zhao, et al. - RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering - 2021
- [[7]](https://arxiv.org/abs/2010.08240) Nandan Thakur, Nils Reimers, Johannes Daxenberger, Iryna Gurevych - Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks - 2020
- [[8]](https://arxiv.org/abs/2103.00020) Alec Radford, Jong Wook Kim, et al. - Learning Transferable Visual Models From Natural Language Supervision - 2021
- [[9]](https://arxiv.org/abs/2004.04906) Vladimir Karpukhin, Barlas Oğuz, et al. - Dense Passage Retrieval for Open Domain Question Answering - 2020
- [[10]](https://arxiv.org/abs/2210.07316) Niklas Muennighoff, Nouamane Tazi, Loïc Magne, Nils Reimers - MTEB: Massive Text Embedding Benchmark - 2022
- [[11]](https://arxiv.org/abs/1807.03748) Aaron van den Oord, Yazhe Li, Oriol Vinyals - Representation Learning with Contrastive Predictive Coding - 2018
- [[12]](https://arxiv.org/abs/2002.05709) Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton - A Simple Framework for Contrastive Learning of Visual Representations - 2020
- [[13]](https://arxiv.org/abs/2212.03533) Liang Wang, Nan Yang, Xiaolong Huang, Binxing Jiao, Linjun Yang, Daxin Jiang, Rangan Majumder, Furu Wei - Text Embeddings by Weakly-Supervised Contrastive Pre-training - 2022
- [[14]](https://arxiv.org/abs/2104.02057) Xinlei Chen, Saining Xie, et al. - An Empirical Study of Training Self-Supervised Vision Transformers - 2021
- [[15]](https://arxiv.org/abs/2006.11632) Jui-Ting Huang, Ashish Sharma, Shuying Sun, Li Xia, David Zhang, Philip Pronin, Janani Padmanabhan, Giuseppe Ottaviano, Linjun Yang - Embedding-based Retrieval in Facebook Search - 2020