# Normalized Temperature-scaled Cross Entropy Loss (NT-Xent)

Links:

* https://paperswithcode.com/method/nt-xent
* https://github.com/KevinMusgrave/pytorch-metric-learning/issues/6

### The setup

This loss function is applied to N feature vectors (one feature vector computed for each input), where N is the batch size.

The $1^{st}$ and $2^{nd}$ feature vectors in the set are considered to be similar to each other (as used in SimCLR for self-supervised learning). The rest are considered to be dissimilar to the $1^{st}$ one. Similarly, the $3^{rd}$ and $4^{th}$ feature vectors are considered to be similar to each other and the rest are considered to be dissimilar to the $3^{rd}$ one.
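The pairing convention above implies a specific row order for the batch. A minimal sketch of how two sets of augmented views could be interleaved to match it is shown below. The names `z1` and `z2` are hypothetical and stand for the model outputs on the first and second augmentation of each image.

```python
import torch

torch.manual_seed(0)

# Hypothetical embeddings for two augmented views of the same 4 images.
batch, dim = 4, 2
z1 = torch.randn(batch, dim)  # view 1 of images 0..3
z2 = torch.randn(batch, dim)  # view 2 of images 0..3

# Interleave the views so that rows (0, 1), (2, 3), ... hold the two
# views of the same image, which is the layout assumed in this notebook.
x = torch.stack([z1, z2], dim=1).reshape(2 * batch, dim)
print(x.shape)
```

With this layout, the even rows of `x` are the first views and the odd rows are the second views.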

### Simple interpretation

NT-Xent is just a very fancy way to say the following:

1. The all-pairs Cosine Similarity score is computed for each of the N vectors produced by the SimCLR model.
2. The similarity of each vector with itself (the diagonal of the matrix) is discarded, since a vector is always perfectly similar to itself and so can't possibly allow the model to learn anything useful
3. Each value (cosine similarity) is scaled by a temperature $\tau$ (which is a hyper-parameter)
4. Cross Entropy Loss is applied to each row of the resulting matrix above
5. Typically, the mean of these losses (one loss per element in a batch) is used for backpropagation

### What this notebook will show

The first part will show an intuitive and efficient implementation of the NT-Xent loss in PyTorch for single-label (only 1 pair can be positive) contrastive learning tasks.

The second part will show an intuitive and efficient implementation of the NT-BXent (binary cross entropy) loss (defined below) in PyTorch for multi-label (multiple pairs can be positive) contrastive learning tasks.

In [90]:
import torch
from torch import nn
from torch.nn import functional as F
_ = torch.manual_seed(21)

In this example, we assume that the input (and output) batch size is 8.

In [91]:
eye = torch.eye(8)
eye, ~eye.bool()

(tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1.]]),
 tensor([[False,  True,  True,  True,  True,  True,  True,  True],
         [ True, False,  True,  True,  True,  True,  True,  True],
         [ True,  True, False,  True,  True,  True,  True,  True],
         [ True,  True,  True, False,  True,  True,  True,  True],
         [ True,  True,  True,  True, False,  True,  True,  True],
         [ True,  True,  True,  True,  True, False,  True,  True],
         [ True,  True,  True,  True,  True,  True, False,  True],
         [ True,  True,  True,  True,  True,  True,  True, False]]))

Let's assume that the model's output is an `(8, 2)` tensor, with each row being a 2d feature vector.

In [92]:
x = torch.randn(8, 2)
x

tensor([[-0.2386, -1.0934],
        [ 0.1558,  0.1750],
        [-0.9526, -0.5442],
        [ 1.1985,  0.9604],
        [-1.1074, -0.8403],
        [-0.0020,  0.2240],
        [ 0.8766, -0.5379],
        [-0.2994,  0.9785]])

Let's compute the cosine similarity between every pair of vectors. Since we have 8 vectors, we expect `8x8=64` similarity scores.

In [93]:
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
xcs

tensor([[ 1.0000, -0.8714,  0.6698, -0.7773,  0.7604, -0.9751,  0.3293, -0.8719],
        [-0.8714,  1.0000, -0.9479,  0.9860, -0.9812,  0.7408,  0.1763,  0.5195],
        [ 0.6698, -0.9479,  1.0000, -0.9878,  0.9916, -0.4883, -0.4806, -0.2203],
        [-0.7773,  0.9860, -0.9878,  1.0000, -0.9997,  0.6183,  0.3381,  0.3696],
        [ 0.7604, -0.9812,  0.9916, -0.9997,  1.0000, -0.5973, -0.3628, -0.3449],
        [-0.9751,  0.7408, -0.4883,  0.6183, -0.5973,  1.0000, -0.5306,  0.9588],
        [ 0.3293,  0.1763, -0.4806,  0.3381, -0.3628, -0.5306,  1.0000, -0.7495],
        [-0.8719,  0.5195, -0.2203,  0.3696, -0.3449,  0.9588, -0.7495,  1.0000]])
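An equivalent way to get the same matrix, sketched here as an alternative rather than as the notebook's method, is to L2-normalize each row and take a matrix product. This avoids materializing the broadcast `(8, 8, 2)` intermediate.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
x = torch.randn(8, 2)

# Broadcast version, as used in this notebook.
xcs_broadcast = F.cosine_similarity(x[None, :, :], x[:, None, :], dim=-1)

# Alternative: normalize each row to unit length, then take dot products.
xn = F.normalize(x, dim=1)
xcs_matmul = xn @ xn.T

print(torch.allclose(xcs_broadcast, xcs_matmul, atol=1e-5))
```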

Let's replace all the diagonal elements (cosine similarity of an element with itself) with `-inf` so that when we compute the softmax later, their probabilities will show up as `0.0`.

In [94]:
y = xcs.clone()
y[eye.bool()] = float("-inf")
y

tensor([[   -inf, -0.8714,  0.6698, -0.7773,  0.7604, -0.9751,  0.3293, -0.8719],
        [-0.8714,    -inf, -0.9479,  0.9860, -0.9812,  0.7408,  0.1763,  0.5195],
        [ 0.6698, -0.9479,    -inf, -0.9878,  0.9916, -0.4883, -0.4806, -0.2203],
        [-0.7773,  0.9860, -0.9878,    -inf, -0.9997,  0.6183,  0.3381,  0.3696],
        [ 0.7604, -0.9812,  0.9916, -0.9997,    -inf, -0.5973, -0.3628, -0.3449],
        [-0.9751,  0.7408, -0.4883,  0.6183, -0.5973,    -inf, -0.5306,  0.9588],
        [ 0.3293,  0.1763, -0.4806,  0.3381, -0.3628, -0.5306,    -inf, -0.7495],
        [-0.8719,  0.5195, -0.2203,  0.3696, -0.3449,  0.9588, -0.7495,    -inf]])
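A quick check of the effect of the `-inf` entries is sketched below, using a freshly generated stand-in for `x`. After the softmax, the diagonal probabilities are exactly zero and each row still sums to one.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
x = torch.randn(8, 2)

y = F.cosine_similarity(x[None, :, :], x[:, None, :], dim=-1)
y[torch.eye(8).bool()] = float("-inf")

# exp(-inf) is 0, so the masked entries get zero probability.
probs = F.softmax(y, dim=1)
print(probs.diagonal())
print(probs.sum(dim=1))
```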

Before we jump into using Cross Entropy loss to compute the loss, let's visualize the ground truth labels we wish to assign to every element in the cosine similarity matrix above.

The tensor `target` below will contain the ground-truth labels (one class index per row) we will feed in to the `cross_entropy` function to compute the cross entropy loss. However, when visualizing the labels, we wish to view them as a 2d matrix, and not as a 1d tensor with one label per element.

When interpreting the matrix `y` above, we need to keep in mind that the first row represents the Cosine Similarity between the element at index 0 and at every index between 0 and 7. The second row represents the Cosine Similarity between the element at index 1 and at every index between 0 and 7, and so on.

The ground truth labels can be visualized as shown below:

In [95]:
target = torch.arange(8)
target[0::2] += 1
target[1::2] -= 1
target

tensor([1, 0, 3, 2, 5, 4, 7, 6])
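Since the positive partner of an even index is the next index and the partner of an odd index is the previous one, the same labels can also be produced by flipping the lowest bit of each index. This is a compact alternative sketch, not what the notebook uses.

```python
import torch

# Construction used in this notebook.
target = torch.arange(8)
target[0::2] += 1
target[1::2] -= 1

# Alternative: XOR with 1 flips the lowest bit, mapping 0<->1, 2<->3, ...
target_xor = torch.arange(8) ^ 1

print(target_xor)
```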

The tensor `index` is going to be a column tensor (8 rows, 1 column) with the same elements as `target`. It will be used to scatter ones into the zero-initialized 2d tensor `ground_truth_labels`.

In [96]:
index = target.reshape(8, 1).long()
index

tensor([[1],
        [0],
        [3],
        [2],
        [5],
        [4],
        [7],
        [6]])

In [97]:
ground_truth_labels = torch.zeros(8, 8).long()
src = torch.ones(8, 8).long()
ground_truth_labels = torch.scatter(ground_truth_labels, 1, index, src)
ground_truth_labels

tensor([[0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 1, 0]])
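The scatter above is effectively a one-hot encoding of `target`. As a sketch, `F.one_hot` gives the same matrix for visualization purposes.

```python
import torch
from torch.nn import functional as F

target = torch.arange(8)
target[0::2] += 1
target[1::2] -= 1

# Scatter-based construction, as in the cell above.
via_scatter = torch.scatter(
    torch.zeros(8, 8).long(), 1, target.reshape(8, 1), torch.ones(8, 8).long()
)

# One-hot encoding of the same class indices.
via_one_hot = F.one_hot(target, num_classes=8)

print(torch.equal(via_scatter, via_one_hot))
```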

What follows is a more detailed explanation of why we create `target` the way we do above.

If we wish to compute the row-wise (or column-wise) cross entropy loss, we observe that in the first row (index 0), the element at index 1 is expected to have the label 1 (similar), and the rest of the elements are dissimilar to the element at index 0. This is because the elements at index 0 and 1 are expected to be augmentations of the same image when using SimCLR.

For the 2nd row (index 1), the element at index 0 should have the label 1 (mirror-image case of the one above), and the rest are dissimilar to the element at index 1.

Similarly, in the remaining rows (indices 2 to 7), the elements at column indices 3, 2, 5, 4, 7, and 6 respectively will have the label `1`.

The tensor `target` contains the ground-truth labels (one class index per row) for feeding into the cross entropy loss function, and the tensor `y` contains the corresponding logits.

The way that the cross entropy loss is used here is slightly different from how it is used in classification tasks. In classification tasks, the network's classification head is trained to predict one of N classes for a given input image. In the SimCLR algorithm, we aren't dealing with classes. We compute the all-pairs cosine similarity of each image with all the other images in the batch, and then compare the resulting distribution of probabilities (of how probable it is for an image in the batch to be similar to the i'th image) with the ground-truth distribution using cross entropy loss. We could have used KL divergence too. However, when the target probabilities are all either 0 or 1, the entropy of the target distribution is zero and KL divergence reduces to cross entropy, so we just use cross entropy loss.

Here's what the documentation says about the [cross_entropy](https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html) function:

> This criterion computes the cross entropy loss between input logits and target.

In [98]:
F.cross_entropy(y, target, reduction="mean")

tensor(2.8555)
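To make explicit what `cross_entropy` does with these inputs, the sketch below recomputes it by hand: take the log-softmax of each row, pick out the log-probability of the positive partner, negate, and average. It uses a freshly generated stand-in for `x`, so the printed value is not expected to match the one above.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
x = torch.randn(8, 2)

y = F.cosine_similarity(x[None, :, :], x[:, None, :], dim=-1)
y[torch.eye(8).bool()] = float("-inf")

target = torch.arange(8)
target[0::2] += 1
target[1::2] -= 1

# Row-wise log-probabilities, then select the entry of the positive partner.
log_probs = F.log_softmax(y, dim=1)
manual = -log_probs[torch.arange(8), target].mean()

builtin = F.cross_entropy(y, target, reduction="mean")
print(manual.item(), builtin.item())
```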

## Final implementation of the NT-Xent Loss

Now that we have seen how to build up the NT-Xent loss, we can condense the complete implementation (including the computation of the all-pairs Cosine Similarity) into this one simple function below.

In [99]:
def nt_xent_loss(x, temperature):
    assert len(x.size()) == 2
    
    # Cosine similarity
    xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
    xcs[torch.eye(x.size(0)).bool()] = float("-inf")

    # Ground truth labels
    target = torch.arange(x.size(0))
    target[0::2] += 1
    target[1::2] -= 1
    
    # Standard cross entropy loss
    return F.cross_entropy(xcs / temperature, target, reduction="mean")

for t in (0.01, 0.1, 1.0, 10.0, 20.0):
    print(f"Temperature: {t:5.2f}, Loss: {nt_xent_loss(x, temperature=t)}")

Temperature:  0.01, Loss: 167.33396911621094
Temperature:  0.10, Loss: 16.916988372802734
Temperature:  1.00, Loss: 2.8555006980895996
Temperature: 10.00, Loss: 2.0152008533477783
Temperature: 20.00, Loss: 1.979940414428711
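The trend in the values above has a simple limit. As the temperature grows, the scaled similarities all approach zero, so each row's softmax approaches a uniform distribution over the other N-1 elements and the loss approaches log(N-1). The sketch below checks this limiting case directly by using all-zero logits. For N = 8 the limit is log(7), roughly 1.9459.

```python
import math

import torch
from torch.nn import functional as F

n = 8

# All-zero logits model the infinite-temperature limit; the diagonal is
# still masked out with -inf, as in nt_xent_loss.
logits = torch.zeros(n, n)
logits[torch.eye(n).bool()] = float("-inf")

target = torch.arange(n)
target[0::2] += 1
target[1::2] -= 1

limit = F.cross_entropy(logits, target, reduction="mean")
print(limit.item(), math.log(n - 1))
```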


# Multi-label loss for contrastive learning (NT-BXent)

In the previous single-label setting, there's exactly 1 positive example in each row, with the rest being negative examples. In a multi-label setting, it's possible to have more than one positive example in a given row, with the rest being negative examples. This can occur when we set up SimCLR so that each image is augmented more than twice, which results in multiple (> 1) positive pairs for each image.

In the single-label case, we could use cross entropy loss. However, in the multi-label setting, we can't, since cross entropy is effectively KL divergence, which measures the difference between 2 probability distributions. A distribution can't have a total probability greater than 1.0, and we don't want the positive examples to each have a probability significantly less than 1.0, which would have to happen if we want all the probabilities in a row to sum up to 1.0. Hence, we turn our attention to Binary Cross Entropy Loss, and treat each pair independently as being 0.0 (negative) or 1.0 (positive). This leads to a slightly different formulation in code.

For each pair $(i,j)$, we compute $s_{ij}$ as the cosine similarity between that pair of elements in the all-pairs cosine similarity matrix. The value $l_{ij}$ computes the loss for a specific element.

$$ l_{ij} = -\left[ y_{ij} \cdot \log{\sigma(s_{ij} / \tau)} + (1 - y_{ij}) \cdot \log{\left(1 - \sigma(s_{ij} / \tau)\right)} \right] $$
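As a sanity check on the sign and on the negative-pair term, the sketch below compares a hand-written element-wise binary cross entropy, $-[y \log p + (1 - y) \log(1 - p)]$ with $p = \sigma(s / \tau)$, against PyTorch's `F.binary_cross_entropy`. The similarity values, labels, and temperature are made up for illustration.

```python
import torch
from torch.nn import functional as F

# Hypothetical similarities s_ij and labels y_ij for a 2x2 example.
s = torch.tensor([[0.9, -0.3], [-0.3, 0.5]])
y = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
tau = 0.5

p = torch.sigmoid(s / tau)

# Element-wise binary cross entropy written out by hand.
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p))

builtin = F.binary_cross_entropy(p, y, reduction="none")
print(manual)
print(builtin)
```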

We can then compute the mean or weighted mean across all the element-wise losses. The unweighted mean is:

$$ l_i = \frac{1}{N} \sum_{j=1}^N {l_{ij}} $$

However, there's a class imbalance problem with the formulation above. Suppose that in a set of N=100 samples, there are only 2 positive pairs and 98 negative pairs, then the loss from the negative pairs will overpower the loss from the positive pairs. Hence, we need to weight the losses from the positive and negative pairs accordingly. In the formula below, $1_{ij}^{pos}$ is $1$ if $ij$ is a positive pair, and $0$ otherwise. Similarly, $1_{ij}^{neg}$ is $1$ if $ij$ is a negative pair, and $0$ otherwise. $N_{pos}$ and $N_{neg}$ are the number of positive and negative pairs respectively.

$$ l_i = \frac{1}{N_{pos}} \sum_{j=1}^N {1_{ij}^{pos} l_{ij}} + \frac{1}{N_{neg}} \sum_{j=1}^N {1_{ij}^{neg} l_{ij}} $$

We can call this formulation the **NT-BXent** for **Normalized Temperature-scaled Binary Cross Entropy Loss**.

In [100]:
# The input to the loss function will be 3 values:
#
# [1] x: Tensor with shape (Batch,NumFeatures). The feature vector for
# each input is computed and returned by the model in its forward() pass.
#
# [2] pos_indices: Tensor with shape (N,2), where N is the number of positive
# pairs. Each positive pair is a 0-indexed (Row,Col) position within the
# (Batch,Batch) all-pairs cosine similarity matrix computed from x.
#
# [3] temperature: float. The temperature to scale the raw logits by before
# applying the sigmoid activation function.
#
target = torch.zeros(8, 8)
pos_indices = torch.tensor([
    (0, 0), (0, 2), (0, 4),
    (1, 4), (1, 6), (1, 1),
    (2, 3),
    (3, 7),
    (4, 3),
    (7, 6),
])
# Add indexes of the principal diagonal as positive indexes. This will be useful
# since we will use the BCELoss in PyTorch, which will expect a value for the
# elements on the principal diagonal as well.
pos_indices = torch.cat([pos_indices, torch.arange(8).reshape(8, 1).expand(-1, 2)], dim=0)
print("\nPositive indexes list")
print(pos_indices)


# Set the values in the target vector to 1.
target[pos_indices[:,0], pos_indices[:,1]] = 1
print(f"\nGround Truth labels for positive and negative pairs for BCE Loss")
print(target)


Positive indexes list
tensor([[0, 0],
        [0, 2],
        [0, 4],
        [1, 4],
        [1, 6],
        [1, 1],
        [2, 3],
        [3, 7],
        [4, 3],
        [7, 6],
        [0, 0],
        [1, 1],
        [2, 2],
        [3, 3],
        [4, 4],
        [5, 5],
        [6, 6],
        [7, 7]])

Ground Truth labels for positive and negative pairs for BCE Loss
tensor([[1., 0., 1., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 1., 0., 1., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 1.]])


## Documentation for binary cross entropy loss computation

* [torch.nn.functional.binary_cross_entropy_with_logits](https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy_with_logits.html)
* [torch.nn.functional.binary_cross_entropy](https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy.html)

The `_with_logits` version accepts raw logits. However, we use the version that accepts values in the range 0.0 to 1.0 (i.e. we manually apply the Sigmoid activation function first), since the version that accepts raw logits uses the [log-sum-exp-trick](https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/), and causes a `NaN` result when fed with $\infty$ or $-\infty$ as input. [This comment](https://github.com/pytorch/pytorch/issues/49844#issuecomment-1574693686) on a related GitHub issue provides details of why the log-sum-exp-trick causes a problem with such values, whereas running Sigmoid first avoids these issues.
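The sketch below checks only the path used in this notebook: applying the sigmoid first maps an infinite logit to exactly 1.0, and `binary_cross_entropy` then returns a finite loss of zero for a target of 1.0. The second element is an ordinary logit, included for comparison.

```python
import math

import torch
from torch.nn import functional as F

# One infinite logit (as on the diagonal) and one ordinary logit.
logits = torch.tensor([float("inf"), 0.0])
target = torch.tensor([1.0, 1.0])

probs = logits.sigmoid()  # sigmoid(inf) = 1.0, sigmoid(0) = 0.5
loss = F.binary_cross_entropy(probs, target, reduction="none")

print(probs)
print(loss)
```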

In [101]:
y = xcs.clone()
# In our predicted all-pairs cosine similarity, we set the diagonal elements
# to infinity so that the sigmoid activation, when applied to these elements,
# will set them to 1.0. Each element is then considered to be perfectly
# similar to itself, which matches its target of 1.0, so these specific pairs
# contribute zero loss and won't generate gradients when the loss is
# propagated into the weights.
y[eye.bool()] = float("inf")
y, y.sigmoid()

(tensor([[    inf, -0.8714,  0.6698, -0.7773,  0.7604, -0.9751,  0.3293, -0.8719],
         [-0.8714,     inf, -0.9479,  0.9860, -0.9812,  0.7408,  0.1763,  0.5195],
         [ 0.6698, -0.9479,     inf, -0.9878,  0.9916, -0.4883, -0.4806, -0.2203],
         [-0.7773,  0.9860, -0.9878,     inf, -0.9997,  0.6183,  0.3381,  0.3696],
         [ 0.7604, -0.9812,  0.9916, -0.9997,     inf, -0.5973, -0.3628, -0.3449],
         [-0.9751,  0.7408, -0.4883,  0.6183, -0.5973,     inf, -0.5306,  0.9588],
         [ 0.3293,  0.1763, -0.4806,  0.3381, -0.3628, -0.5306,     inf, -0.7495],
         [-0.8719,  0.5195, -0.2203,  0.3696, -0.3449,  0.9588, -0.7495,     inf]]),
 tensor([[1.0000, 0.2950, 0.6615, 0.3149, 0.6814, 0.2739, 0.5816, 0.2949],
         [0.2950, 1.0000, 0.2793, 0.7283, 0.2726, 0.6772, 0.5440, 0.6270],
         [0.6615, 0.2793, 1.0000, 0.2713, 0.7294, 0.3803, 0.3821, 0.4451],
         [0.3149, 0.7283, 0.2713, 1.0000, 0.2690, 0.6498, 0.5837, 0.5914],
         [0.6814, 0.2726, 0.7294, 

In [102]:
temperature = 0.1
loss = F.binary_cross_entropy((y / temperature).sigmoid(), target, reduction="none")

target_pos = target.bool()
target_neg = ~target_pos

print("\nPositive pairs mask")
print(target_pos)

print("\nNegative pairs mask")
print(target_neg)

# loss_pos and loss_neg below contain non-zero values only for those elements
# that are positive pairs and negative pairs respectively.
loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])
loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])

print("\nPositive pairs only (loss)")
print(loss_pos)

print("\nNegative pairs only (loss)")
print(loss_neg)

# loss_pos and loss_neg now contain the sum of positive and negative pair losses
# as computed relative to the i'th input.
loss_pos = loss_pos.sum(dim=1)
loss_neg = loss_neg.sum(dim=1)

print("\nElement-wise positive pairs loss")
print(loss_pos)

print("\nElement-wise negative pairs loss")
print(loss_neg)

# num_pos and num_neg below contain the number of positive and negative pairs
# computed relative to the i'th input. In an actual setting, this number should
# be the same for every input element, but we let it vary here for maximum
# flexibility.
num_pos = target.sum(dim=1)
num_neg = target.size(0) - num_pos

print("\nNumber of elements similar to the i'th element")
print(num_pos)

print("\nNumber of elements dissimilar to the i'th element")
print(num_neg)

# Compute the weighted overall loss as seen in the formula above.
overall_loss = ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()
print(f"\nOverall loss: {overall_loss.item()}")


Positive pairs mask
tensor([[ True, False,  True, False,  True, False, False, False],
        [False,  True, False, False,  True, False,  True, False],
        [False, False,  True,  True, False, False, False, False],
        [False, False, False,  True, False, False, False,  True],
        [False, False, False,  True,  True, False, False, False],
        [False, False, False, False, False,  True, False, False],
        [False, False, False, False, False, False,  True, False],
        [False, False, False, False, False, False,  True,  True]])

Negative pairs mask
tensor([[False,  True, False,  True, False,  True,  True,  True],
        [ True, False,  True,  True, False,  True, False,  True],
        [ True,  True, False, False,  True,  True,  True,  True],
        [ True,  True,  True, False,  True,  True,  True, False],
        [ True,  True,  True, False, False,  True,  True,  True],
        [ True,  True,  True,  True,  True, False,  True,  True],
        [ True,  True,  True,  Tr

## Final implementation of NT-BXent Loss

Now that we have seen how to build up the NT-BXent loss, we can condense the complete implementation (including the computation of the all-pairs Cosine Similarity) into this one simple function below.

In [103]:
def nt_bxent_loss(x, pos_indices, temperature):
    assert len(x.size()) == 2

    # Add indexes of the principal diagonal elements to pos_indices
    pos_indices = torch.cat([
        pos_indices,
        torch.arange(x.size(0)).reshape(x.size(0), 1).expand(-1, 2),
    ], dim=0)
    
    # Ground truth labels
    target = torch.zeros(x.size(0), x.size(0))
    target[pos_indices[:,0], pos_indices[:,1]] = 1.0

    # Cosine similarity
    xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
    # Set logit of diagonal element to "inf" signifying complete
    # correlation. sigmoid(inf) = 1.0 so this will work out nicely
    # when computing the Binary Cross Entropy Loss.
    xcs[torch.eye(x.size(0)).bool()] = float("inf")

    # Standard binary cross entropy loss. We use binary_cross_entropy() here and not
    # binary_cross_entropy_with_logits() because of https://github.com/pytorch/pytorch/issues/102894
    # The method *_with_logits() uses the log-sum-exp-trick, which causes inf and -inf values
    # to result in a NaN result.
    loss = F.binary_cross_entropy((xcs / temperature).sigmoid(), target, reduction="none")
    
    target_pos = target.bool()
    target_neg = ~target_pos
    
    loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])
    loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])
    loss_pos = loss_pos.sum(dim=1)
    loss_neg = loss_neg.sum(dim=1)
    num_pos = target.sum(dim=1)
    num_neg = x.size(0) - num_pos

    return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()

pos_indices = torch.tensor([
    (0, 0), (0, 2), (0, 4),
    (1, 4), (1, 6), (1, 1),
    (2, 3),
    (3, 7),
    (4, 3),
    (7, 6),
])
for t in (0.01, 0.1, 1.0, 10.0, 20.0):
    print(f"Temperature: {t:5.2f}, Loss: {nt_bxent_loss(x, pos_indices, temperature=t)}")

Temperature:  0.01, Loss: 62.898780822753906
Temperature:  0.10, Loss: 4.851151943206787
Temperature:  1.00, Loss: 1.0727109909057617
Temperature: 10.00, Loss: 0.9827173948287964
Temperature: 20.00, Loss: 0.982099175453186
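One possible simplification of the function above is to replace the `masked_scatter` calls with a multiplication by the 0/1 target mask, which gives the same row-wise sums. The sketch below checks that equivalence on made-up values. Here `loss` is a random stand-in for the element-wise BCE losses, and `target` is a random 0/1 label matrix with the diagonal set to 1.

```python
import torch

torch.manual_seed(0)
n = 8

# Stand-ins for the element-wise losses and the ground-truth labels.
loss = torch.rand(n, n)
target = (torch.rand(n, n) > 0.7).float()
target[torch.eye(n).bool()] = 1.0

target_pos = target.bool()
target_neg = ~target_pos

# Row-wise sums as computed in nt_bxent_loss.
pos_scatter = torch.zeros(n, n).masked_scatter(target_pos, loss[target_pos]).sum(dim=1)
neg_scatter = torch.zeros(n, n).masked_scatter(target_neg, loss[target_neg]).sum(dim=1)

# Same sums via multiplication by the mask.
pos_product = (loss * target).sum(dim=1)
neg_product = (loss * (1.0 - target)).sum(dim=1)

print(torch.allclose(pos_scatter, pos_product), torch.allclose(neg_scatter, neg_product))
```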
