# Normalized Temperature-scaled Cross Entropy Loss (NT-Xent)

https://paperswithcode.com/method/nt-xent

### The setup

This loss function is applied to a batch of N vectors (embeddings), N being the batch size.

The 1st and 2nd vectors in the batch are considered similar to each other (a positive pair, as in SimCLR's self-supervised setup), while every other vector in the batch is considered dissimilar to the 1st one. The same holds for every consecutive pair: the 3rd and 4th, the 5th and 6th, and so on.
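In SimCLR, a positive pair comes from two augmentations of the same image. The following is a minimal sketch of how two views could be interleaved into the layout this notebook uses, where rows 0 and 1, rows 2 and 3, and so on are positive pairs. The tensors `z1` and `z2` are hypothetical stand-ins for the embeddings of the two views.

```python
import torch

B, D = 4, 2  # hypothetical: 4 source images, 2-D embeddings

# Hypothetical embeddings of two augmented views of the same B images
z1 = torch.randn(B, D)  # view 1 of images 0..B-1
z2 = torch.randn(B, D)  # view 2 of images 0..B-1

# stack -> (B, 2, D); reshape -> (2B, D), so rows 2k and 2k+1 are the two views of image k
x = torch.stack([z1, z2], dim=1).reshape(2 * B, D)

print(x.shape)  # torch.Size([8, 2])
```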

### Simple interpretation

NT-Xent is just a very fancy way to say the following:

1. The cosine similarity is computed between every pair of the N vectors, giving an N×N matrix.
2. Self-comparisons on the diagonal are discarded, since a vector is always perfectly similar to itself.
3. Each remaining cosine similarity is divided by a temperature `T`, which is a hyperparameter.
4. Cross entropy loss is applied to each row of the resulting matrix, treating the row as logits and the index of the row's positive partner as the target class.
5. Typically, the mean of these per-row losses is used for backpropagation.
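Written as a formula, steps 1 to 5 amount to the following. The notation is ours: $s_{ij}$ is the cosine similarity between vectors $z_i$ and $z_j$, $T$ is the temperature, and $p(i)$ is the index of the positive partner of row $i$, assuming (as in this notebook) 0-based indices with partners in adjacent rows.

$$
\ell_i = -\log \frac{\exp\left(s_{i,p(i)}/T\right)}{\sum_{k \neq i} \exp\left(s_{ik}/T\right)},
\qquad
\mathcal{L} = \frac{1}{N} \sum_{i=0}^{N-1} \ell_i,
\qquad
p(i) = \begin{cases} i + 1 & i \text{ even} \\ i - 1 & i \text{ odd} \end{cases}
$$

The denominator skips $k = i$, which is what masking the diagonal with `-inf` achieves in the code.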

```python
import torch
from torch import nn
from torch.nn import functional as F
_ = torch.manual_seed(21)
```

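In this example, we assume a batch size of 8. The identity matrix `eye` below marks the diagonal of an 8x8 matrix, and its boolean negation `~eye.bool()` marks every off-diagonal position.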

```python
eye = torch.eye(8)
eye, ~eye.bool()
```

```
(tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1.]]),
 tensor([[False,  True,  True,  True,  True,  True,  True,  True],
         [ True, False,  True,  True,  True,  True,  True,  True],
         [ True,  True, False,  True,  True,  True,  True,  True],
         [ True,  True,  True, False,  True,  True,  True,  True],
         [ True,  True,  True,  True, False,  True,  True,  True],
         [ True,  True,  True,  True,  True, False,  True,  True],
         [ True,  True,  True,  True,  True,  True, False,  True],
         [ True,  True,  True,  True,  True,  True,  True, False]]))
```

Let's assume that the model's output is an `(8, 2)` tensor: 8 rows, each of which is a 2-D embedding vector.

```python
x = torch.randn(8, 2)
x
```

```
tensor([[-0.2386, -1.0934],
        [ 0.1558,  0.1750],
        [-0.9526, -0.5442],
        [ 1.1985,  0.9604],
        [-1.1074, -0.8403],
        [-0.0020,  0.2240],
        [ 0.8766, -0.5379],
        [-0.2994,  0.9785]])
```

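Let's compute the cosine similarity between every pair of vectors. Indexing with `None` gives `x` the shapes `(1, 8, 2)` and `(8, 1, 2)`, which broadcast to `(8, 8, 2)`; taking the similarity along the last dimension (`dim=-1`) leaves an `8x8` matrix of 64 similarity scores.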

```python
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
xcs
```

```
tensor([[ 1.0000, -0.8714,  0.6698, -0.7773,  0.7604, -0.9751,  0.3293, -0.8719],
        [-0.8714,  1.0000, -0.9479,  0.9860, -0.9812,  0.7408,  0.1763,  0.5195],
        [ 0.6698, -0.9479,  1.0000, -0.9878,  0.9916, -0.4883, -0.4806, -0.2203],
        [-0.7773,  0.9860, -0.9878,  1.0000, -0.9997,  0.6183,  0.3381,  0.3696],
        [ 0.7604, -0.9812,  0.9916, -0.9997,  1.0000, -0.5973, -0.3628, -0.3449],
        [-0.9751,  0.7408, -0.4883,  0.6183, -0.5973,  1.0000, -0.5306,  0.9588],
        [ 0.3293,  0.1763, -0.4806,  0.3381, -0.3628, -0.5306,  1.0000, -0.7495],
        [-0.8719,  0.5195, -0.2203,  0.3696, -0.3449,  0.9588, -0.7495,  1.0000]])
```
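The matrix is symmetric with ones on the diagonal. As a sanity check, here is a small sketch, using its own random input rather than the notebook's `x`, showing that the broadcasted `cosine_similarity` agrees with an equivalent formulation: L2-normalize every row, then take all pairwise dot products with a single matrix multiplication.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
x = torch.randn(8, 2)

# Broadcasted version: (1, 8, 2) vs (8, 1, 2) -> (8, 8, 2), cosine over the last dim -> (8, 8)
xcs = F.cosine_similarity(x[None, :, :], x[:, None, :], dim=-1)

# Equivalent version: unit-normalize each row, then all pairwise dot products
xn = F.normalize(x, dim=-1)
xcs_mm = xn @ xn.T

print(xcs.shape)  # torch.Size([8, 8])
print(torch.allclose(xcs, xcs_mm, atol=1e-6))
```

The matrix-multiplication form avoids materializing the `(N, N, D)` broadcast intermediate, which may matter for large batches or wide embeddings.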

Let's replace all the diagonal elements (cosine similarity of an element with itself) with `-inf` so that when we compute the softmax later, it will show up as `0.0`.

```python
y = xcs.clone()
y[eye.bool()] = float("-inf")
y
```

```
tensor([[   -inf, -0.8714,  0.6698, -0.7773,  0.7604, -0.9751,  0.3293, -0.8719],
        [-0.8714,    -inf, -0.9479,  0.9860, -0.9812,  0.7408,  0.1763,  0.5195],
        [ 0.6698, -0.9479,    -inf, -0.9878,  0.9916, -0.4883, -0.4806, -0.2203],
        [-0.7773,  0.9860, -0.9878,    -inf, -0.9997,  0.6183,  0.3381,  0.3696],
        [ 0.7604, -0.9812,  0.9916, -0.9997,    -inf, -0.5973, -0.3628, -0.3449],
        [-0.9751,  0.7408, -0.4883,  0.6183, -0.5973,    -inf, -0.5306,  0.9588],
        [ 0.3293,  0.1763, -0.4806,  0.3381, -0.3628, -0.5306,    -inf, -0.7495],
        [-0.8719,  0.5195, -0.2203,  0.3696, -0.3449,  0.9588, -0.7495,    -inf]])
```
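To see why `-inf` is the right fill value, here is a small sketch on a hypothetical 4-element row (not taken from `y`). The masked entry receives exactly zero probability, the rest still sums to one, and dividing by a positive temperature leaves the mask intact.

```python
import torch

# A hypothetical row of the similarity matrix whose diagonal entry (index 0) is masked
row = torch.tensor([float("-inf"), -0.5, 0.7, 0.2])

p = torch.softmax(row, dim=0)
print(p[0].item())  # 0.0, because exp(-inf) == 0
print(p.sum().item())  # approximately 1.0: the mass is spread over the remaining entries

# Dividing by a positive temperature keeps -inf at -inf, so the mask survives the scaling
scaled = row / 0.1
print(torch.softmax(scaled, dim=0)[0].item())  # 0.0
```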

Before computing the cross entropy loss, let's visualize the ground-truth labels we wish to assign to every element of the cosine similarity matrix above.

The tensor `target` below holds the ground-truth labels we will feed to the `cross_entropy` function: one class index per row, namely the index of that row's positive partner. For visualization, however, we want to see the labels as a 2-D matrix with a 1 wherever a pair is similar, rather than as a 1-D tensor of indices.

When interpreting the matrix `y` above, keep in mind that row 0 holds the cosine similarities between the element at index 0 and every element at indices 0 through 7, row 1 holds the same for the element at index 1, and so on.

We first build `target`, and then expand it into a 2-D matrix:

```python
target = torch.arange(8)
target[0::2] += 1
target[1::2] -= 1
target
```

```
tensor([1, 0, 3, 2, 5, 4, 7, 6])
```
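Because partners always sit at indices `2k` and `2k + 1`, the partner of index `i` differs from `i` only in its lowest bit. A compact alternative, sketched below, flips that bit with XOR; this is our shorthand, not the notebook's construction, and it assumes an even batch size.

```python
import torch

n = 8  # batch size; must be even for the adjacent-pair layout

# Flipping the lowest bit maps 0<->1, 2<->3, ..., which is the partner index
target_xor = torch.arange(n) ^ 1
print(target_xor)  # tensor([1, 0, 3, 2, 5, 4, 7, 6])

# The notebook's construction, for comparison
target = torch.arange(n)
target[0::2] += 1
target[1::2] -= 1
print(torch.equal(target, target_xor))  # True
```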

The tensor `index` is a column tensor (8 rows, 1 column) holding the same elements as `target`. It is used to scatter ones into the all-zeros 2-D tensor `ground_truth_labels`, placing a single 1 in each row at the column of that row's positive partner.

```python
index = target.reshape(8, 1).long()
index
```

```
tensor([[1],
        [0],
        [3],
        [2],
        [5],
        [4],
        [7],
        [6]])
```

```python
ground_truth_labels = torch.zeros(8, 8).long()
src = torch.ones(8, 8).long()
ground_truth_labels = torch.scatter(ground_truth_labels, 1, index, src)
ground_truth_labels
```

```
tensor([[0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 1, 0]])
```
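The scattered matrix is simply a one-hot encoding of `target`. A short sketch confirming this with `F.one_hot`, and checking that each row has exactly one positive and that the matrix is symmetric (because "is the partner of" goes both ways):

```python
import torch
from torch.nn import functional as F

target = torch.tensor([1, 0, 3, 2, 5, 4, 7, 6])

# scatter along dim 1: for each row i, write src[i][0] (a 1) into column index[i][0]
index = target.reshape(8, 1)
scattered = torch.scatter(
    torch.zeros(8, 8, dtype=torch.long), 1, index, torch.ones(8, 8, dtype=torch.long)
)

# The same matrix, built as a one-hot encoding of the class indices
one_hot = F.one_hot(target, num_classes=8)

print(torch.equal(scattered, one_hot))  # True
```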

What follows is a more detailed explanation of why we create `target` the way we do above.

If we compute the cross entropy loss row by row, the element at index 1 in row 0 should get label 1 (similar), and every other element in that row is dissimilar to the element at index 0. This is because, in SimCLR, the elements at indices 0 and 1 are two augmentations of the same image.

In row 1, the element at index 0 should get label 1 (the mirror image of the case above), and every other element is dissimilar to the element at index 1.

Similarly, in rows 2 through 7, the elements at indices 3, 2, 5, 4, 7, and 6, respectively, get label `1`. This is exactly the sequence stored in `target`.

The tensor `target` holds these ground-truth labels as class indices, while `y` holds the logits; both are fed to the cross entropy loss function. No temperature is applied yet, which is equivalent to `T = 1`.

Here's what the documentation says about the [cross_entropy](https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html) function:

> This criterion computes the cross entropy loss between input logits and target.

```python
F.cross_entropy(
    y,
    target,
    reduction="mean",
)
```

```
tensor(2.8555)
```
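To make explicit what `cross_entropy` does with these inputs, here is a sketch that recomputes it by hand on fresh random data: take the log-softmax of each row, pick the entry at the row's target index, negate, and average over rows. The diagonal `-inf` entries never get picked, so they contribute nothing.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
x = torch.randn(8, 2)

# Masked similarity logits (T = 1) and partner-index targets, built as in the notebook
y = F.cosine_similarity(x[None, :, :], x[:, None, :], dim=-1)
y[torch.eye(8).bool()] = float("-inf")
target = torch.tensor([1, 0, 3, 2, 5, 4, 7, 6])

# cross_entropy == mean over rows of -log softmax(row)[target]
log_probs = F.log_softmax(y, dim=1)
manual = -log_probs[torch.arange(8), target].mean()

print(torch.allclose(manual, F.cross_entropy(y, target)))  # True
```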

## Final implementation

Now that we have seen how to build up the NT-Xent loss, we can condense the complete implementation (including the computation of the all-pairs Cosine Similarity) into this one simple function below.

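```python
def nt_xent_loss(x, temperature):
    # x: (N, D) batch where rows 2k and 2k+1 form a positive pair, so N must be even
    assert x.dim() == 2 and x.size(0) % 2 == 0
    n = x.size(0)

    # All-pairs cosine similarity, shape (N, N), with self-similarity masked out
    xcs = F.cosine_similarity(x[None, :, :], x[:, None, :], dim=-1)
    xcs[torch.eye(n, device=x.device).bool()] = float("-inf")

    # Ground truth labels: the partner of row i is i + 1 (even i) or i - 1 (odd i)
    target = torch.arange(n, device=x.device)
    target[0::2] += 1
    target[1::2] -= 1

    # Temperature-scaled cross entropy, averaged over the N rows
    return F.cross_entropy(xcs / temperature, target, reduction="mean")

for t in (0.01, 0.1, 1.0, 10.0, 20.0):
    print(f"Temperature: {t:5.2f}, Loss: {nt_xent_loss(x, temperature=t)}")
```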

```
Temperature:  0.01, Loss: 167.33396911621094
Temperature:  0.10, Loss: 16.916988372802734
Temperature:  1.00, Loss: 2.8555006980895996
Temperature: 10.00, Loss: 2.0152008533477783
Temperature: 20.00, Loss: 1.979940414428711
```
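At T = 20 the loss above is about 1.98, close to log(7) ≈ 1.9459. That is expected: as T grows, every finite logit divided by T goes to 0, so each row's softmax becomes uniform over its N - 1 = 7 unmasked entries and the loss tends to log(N - 1). The sketch below checks this limit on its own random input; `nt_xent` is a hypothetical compact restatement of the function above, using `^ 1` as shorthand for the same partner indices.

```python
import math

import torch
from torch.nn import functional as F

def nt_xent(x, temperature):
    # Compact restatement: adjacent rows (2k, 2k+1) are positive pairs, N even
    n = x.size(0)
    logits = F.cosine_similarity(x[None, :, :], x[:, None, :], dim=-1)
    logits[torch.eye(n).bool()] = float("-inf")
    target = torch.arange(n) ^ 1  # partner index: 0<->1, 2<->3, ...
    return F.cross_entropy(logits / temperature, target)

torch.manual_seed(0)
x = torch.randn(8, 2)

# For large T the loss should approach log(N - 1) = log(7) for a batch of 8
for t in (1.0, 10.0, 100.0, 1e4):
    print(f"T={t:>8}: loss={nt_xent(x, t).item():.4f}")
print(f"log(7) = {math.log(7):.4f}")
```

At the other extreme, when a row's positive is not its most similar entry, that row's loss grows roughly like 1/T, which is consistent with the roughly tenfold jump from T = 0.1 to T = 0.01 in the output above.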
