In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [2]:
# Toy dimensions for the walkthrough: number of rows and embedding width.
batch_size, hidden_size = 6, 5

In [3]:
# Square identity matrix with one row/column per sample in the batch.
eye = torch.eye(batch_size, batch_size)
eye

tensor([[1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1.]])

In [4]:
# Stand-in for the model output: seeded random values, one row per sample.
torch.manual_seed(42)
x = torch.randn((batch_size, hidden_size))
x

tensor([[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784],
        [-1.2345, -0.0431, -1.6047, -0.7521,  1.6487],
        [-0.3925, -1.4036, -0.7279, -0.5594, -2.3169],
        [-0.2168, -1.3847, -0.8712, -0.2234,  1.7174],
        [ 0.3189, -0.4245, -0.8286,  0.3309, -1.5576],
        [ 0.9956, -0.8798, -0.6011, -1.2742,  2.1228]])

In [5]:
# Pairwise cosine similarity: broadcasting a (1, B, H) view of x against a
# (B, 1, H) view and reducing over the last dim yields a (B, B) matrix.
similarity_matrix = F.cosine_similarity(x.unsqueeze(0), x.unsqueeze(1), dim=-1)
similarity_matrix

tensor([[ 1.0000, -0.1280, -0.3954, -0.1994, -0.3942,  0.4277],
        [-0.1280,  1.0000, -0.2149,  0.7268, -0.3662,  0.5419],
        [-0.3954, -0.2149,  1.0000, -0.1725,  0.8322, -0.3525],
        [-0.1994,  0.7268, -0.1725,  1.0000, -0.3368,  0.7938],
        [-0.3942, -0.3662,  0.8322, -0.3368,  1.0000, -0.4720],
        [ 0.4277,  0.5419, -0.3525,  0.7938, -0.4720,  1.0000]])

In [6]:
# Overwrite the main diagonal (each row's similarity with itself) with -inf,
# modifying similarity_matrix in place.
similarity_matrix.masked_fill_(eye.bool(), float("-inf"))
similarity_matrix

tensor([[   -inf, -0.1280, -0.3954, -0.1994, -0.3942,  0.4277],
        [-0.1280,    -inf, -0.2149,  0.7268, -0.3662,  0.5419],
        [-0.3954, -0.2149,    -inf, -0.1725,  0.8322, -0.3525],
        [-0.1994,  0.7268, -0.1725,    -inf, -0.3368,  0.7938],
        [-0.3942, -0.3662,  0.8322, -0.3368,    -inf, -0.4720],
        [ 0.4277,  0.5419, -0.3525,  0.7938, -0.4720,    -inf]])

In [7]:
# Target vector: consecutive rows (0, 1), (2, 3), ... point at each other.
# XOR with 1 flips the lowest bit, i.e. even index -> +1 and odd index -> -1.
target_matrix = torch.arange(batch_size) ^ 1
target_matrix

tensor([1, 0, 3, 2, 5, 4])

In [8]:
# Column-vector form of the targets, as needed for scatter along dim 1.
index = target_matrix.long().unsqueeze(1)
index

tensor([[1],
        [0],
        [3],
        [2],
        [5],
        [4]])

In [9]:
# One-hot label matrix: start from all zeros and scatter a 1 into the target
# column of every row.
zeros = torch.zeros((batch_size, batch_size), dtype=torch.long)
ones = torch.ones((batch_size, batch_size), dtype=torch.long)

ground_truth_labels = zeros.scatter(dim=1, index=index, src=ones)
ground_truth_labels

tensor([[0, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1, 0]])

In [10]:
# Mean cross entropy with the masked similarities as logits and the paired
# row indices as class targets.
F.cross_entropy(input=similarity_matrix, target=target_matrix, reduction="mean")

tensor(1.9966)

In [14]:
# implementation

def nt_xent_loss(model_output, temperature):
    """Calculate the NT-Xent (normalized temperature-scaled cross entropy) loss.

    Rows of ``model_output`` are treated as consecutive positive pairs: row
    ``2k`` and row ``2k + 1`` are each other's target. Every other row in the
    batch acts as a negative, and a row is never compared with itself.

    Args:
        model_output (Tensor): 2-D tensor of shape ``(batch_size, hidden_size)``
            with an even ``batch_size``.
        temperature (float): Value the cosine similarities are divided by
            before the cross entropy is computed.

    Returns:
        Tensor: Scalar tensor holding the mean NT-Xent loss over the batch.

    Raises:
        ValueError: If ``model_output`` is not 2-D, or if its number of rows
            is odd (the last row would have no partner).
    """
    if model_output.dim() != 2:
        raise ValueError(
            f"model_output must be 2-D (batch_size, hidden_size), got shape {tuple(model_output.shape)}"
        )

    batch_size = model_output.shape[0]
    if batch_size % 2 != 0:
        raise ValueError(
            f"batch_size must be even so that rows form positive pairs, got {batch_size}"
        )

    # Build every helper tensor on the input's device so GPU inputs work too.
    device = model_output.device

    # Pairwise cosine similarity via broadcasting: (1, B, H) vs (B, 1, H) -> (B, B).
    # Uses the function argument, not any notebook-level variable.
    similarity_matrix = F.cosine_similarity(
        model_output.unsqueeze(0),
        model_output.unsqueeze(1),
        dim=-1,
    )

    # Discard the main diagonal. The out-of-place fill leaves the tensor
    # produced by cosine_similarity untouched.
    diagonal = torch.eye(batch_size, dtype=torch.bool, device=device)
    similarity_matrix = similarity_matrix.masked_fill(diagonal, float("-inf"))

    # Labels: even rows point at the next row, odd rows at the previous one.
    labels = torch.arange(batch_size, device=device)
    labels[0::2] += 1
    labels[1::2] -= 1

    # Compute cross entropy loss over the temperature-scaled similarities.
    return F.cross_entropy(similarity_matrix / temperature, labels, reduction="mean")

# Build a seeded batch and print the loss for several temperatures.
torch.manual_seed(42)
batch = torch.randn((batch_size, hidden_size))

for t in [0.01, 0.1, 1.0, 10.0]:
    print(f"Temperature: {t:.2f}, Loss: {nt_xent_loss(batch, temperature=t)}")

Temperature: 0.01, Loss: 99.19322967529297
Temperature: 0.10, Loss: 10.030501365661621
Temperature: 1.00, Loss: 1.9965673685073853
Temperature: 10.00, Loss: 1.6381586790084839
