In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [2]:
batch_size = 32 # bigger batch size for local computer causes badalloc (on A100 can try _ batch_size)
seq_len = 512
hidden_size = 768

In [3]:
eye = torch.eye(batch_size)
eye

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [4]:
# Model output
torch.manual_seed(42)
x = torch.randn(batch_size, seq_len, hidden_size)
x.shape

torch.Size([32, 512, 768])

In [5]:
# flatten seq_len and hidden_size into one dimension
x = x.view(x.size()[0], -1)
x.shape

torch.Size([32, 393216])

In [6]:
x.to(device='cuda')

tensor([[ 1.9269,  1.4873,  0.9007,  ...,  0.4437, -0.3335,  1.3220],
        [ 0.6839,  0.7625, -0.8857,  ..., -0.2343, -0.9000, -0.8702],
        [ 1.2346,  0.6121, -0.3751,  ..., -1.9618, -1.1964,  0.6218],
        ...,
        [ 0.7162,  0.8673, -0.8819,  ...,  0.1874, -0.5655, -0.2217],
        [ 0.6127,  0.3267, -0.4352,  ..., -0.6766,  2.0321,  0.1363],
        [-1.4985, -2.6349,  0.4779,  ..., -0.1444, -0.3367, -0.0716]],
       device='cuda:0')

In [7]:
def get_tensor_size_in_gb(t):
    return (t.element_size() * t.nelement()) / 1_000_000_000

In [8]:
# compute sim between every pair
# TODO: not efficient by memory

similarity_matrix = F.cosine_similarity(
    x.reshape(1, x.size()[0], x.size()[1]), 
    x.reshape(x.size()[0], 1, x.size()[1]), 
    dim=-1
)
similarity_matrix

tensor([[ 1.0000e+00,  2.1379e-03, -1.6717e-03,  ..., -1.2817e-03,
          1.4366e-03,  2.3230e-03],
        [ 2.1379e-03,  1.0000e+00,  1.5649e-04,  ...,  1.3443e-03,
          1.0031e-03, -1.2600e-03],
        [-1.6717e-03,  1.5649e-04,  1.0000e+00,  ...,  8.2536e-04,
          2.9734e-04, -6.1174e-04],
        ...,
        [-1.2817e-03,  1.3443e-03,  8.2536e-04,  ...,  1.0000e+00,
         -8.5864e-04, -1.4561e-03],
        [ 1.4366e-03,  1.0031e-03,  2.9734e-04,  ..., -8.5864e-04,
          1.0000e+00,  1.6072e-03],
        [ 2.3230e-03, -1.2600e-03, -6.1174e-04,  ..., -1.4561e-03,
          1.6072e-03,  1.0000e+00]])

In [9]:
# discard main diagonal

similarity_matrix[eye.bool()] = float("-inf")
similarity_matrix

tensor([[   -inf,  0.0021, -0.0017,  ..., -0.0013,  0.0014,  0.0023],
        [ 0.0021,    -inf,  0.0002,  ...,  0.0013,  0.0010, -0.0013],
        [-0.0017,  0.0002,    -inf,  ...,  0.0008,  0.0003, -0.0006],
        ...,
        [-0.0013,  0.0013,  0.0008,  ...,    -inf, -0.0009, -0.0015],
        [ 0.0014,  0.0010,  0.0003,  ..., -0.0009,    -inf,  0.0016],
        [ 0.0023, -0.0013, -0.0006,  ..., -0.0015,  0.0016,    -inf]])

In [10]:
# target matrix

target_matrix = torch.arange(batch_size)
target_matrix[0::2] += 1
target_matrix[1::2] -= 1
target_matrix

tensor([ 1,  0,  3,  2,  5,  4,  7,  6,  9,  8, 11, 10, 13, 12, 15, 14, 17, 16,
        19, 18, 21, 20, 23, 22, 25, 24, 27, 26, 29, 28, 31, 30])

In [11]:
index = target_matrix.reshape(x.size()[0], 1).long()
index

tensor([[ 1],
        [ 0],
        [ 3],
        [ 2],
        [ 5],
        [ 4],
        [ 7],
        [ 6],
        [ 9],
        [ 8],
        [11],
        [10],
        [13],
        [12],
        [15],
        [14],
        [17],
        [16],
        [19],
        [18],
        [21],
        [20],
        [23],
        [22],
        [25],
        [24],
        [27],
        [26],
        [29],
        [28],
        [31],
        [30]])

In [12]:
# create labels matrix

zeros = torch.zeros(x.size()[0], x.size()[1]).long()
ones = torch.ones(x.size()[0], x.size()[1]).long()


ground_truth_labels = torch.scatter(zeros, 1, index, ones)
ground_truth_labels

tensor([[0, 1, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

In [13]:
F.cross_entropy(similarity_matrix, target_matrix, reduction="mean")

tensor(3.4339)

In [14]:
# implementation

def nt_xent_loss(model_output, temperature):
    """Calculate NT-Xent loss.

    Args:
        model_output (Tensor): Model output
        temperature (float): Loss temperature

    Returns:
        Tensor: NT-Xent loss
    """
    batch_size = model_output.shape[0]
    hidden_size = model_output.shape[1]
    
    # Cosine similarity
    similarity_matrix = F.cosine_similarity(
        x.reshape(1, x.size()[0], x.size()[1]),
        x.reshape(x.size()[0], 1, x.size()[1]),
        dim=-1
    )

    # Discard main diagonal
    similarity_matrix[torch.eye(batch_size).bool()] = float("-inf")

    # Labels
    labels = torch.arange(batch_size)
    labels[0::2] += 1
    labels[1::2] -= 1
    
    # Compute cross entropy loss
    return F.cross_entropy(similarity_matrix / temperature, labels, reduction="mean")

torch.manual_seed(42)
batch = torch.randn(batch_size, hidden_size)

for t in (0.01, 0.1, 1.0, 10.0):
    print(f"Temperature: {t:.2f}, Loss: {nt_xent_loss(batch, temperature=t)}")

Temperature: 0.01, Loss: 3.4320895671844482
Temperature: 0.10, Loss: 3.4327821731567383
Temperature: 1.00, Loss: 3.433856725692749
Temperature: 10.00, Loss: 3.433974266052246
