In [1]:
import torch

In [2]:
# Seed the RNG so that Restart & Run All regenerates exactly the same inputs
# (and therefore the same numbers in every cell below).
torch.manual_seed(0)

# Two batches of 3 random 10-dimensional vectors; row i of X and row i of Y
# are treated as a positive pair.
X = torch.rand(3, 10)
Y = torch.rand(3, 10)
# Temperature that divides every cosine similarity before exponentiation.
tau = 1

## Using linear algebra to compute the NT-Xent loss

In [3]:
# Scale every row to unit L2 norm so that plain dot products between rows
# are cosine similarities.
X_norm = X / torch.norm(X, dim=1, keepdim=True)
Y_norm = Y / torch.norm(Y, dim=1, keepdim=True)
# Stack the two batches in both orders: row i of XY_norm and row i of
# YX_norm always form a positive pair.
XY_norm = torch.cat((X_norm, Y_norm), dim=0)
YX_norm = torch.cat((Y_norm, X_norm), dim=0)
# Pairwise cosine similarities between all stacked rows.
sim_mat_1 = torch.matmul(XY_norm, XY_norm.T)

In [4]:
# Pairwise cosine similarities of the stacked unit vectors (X rows first,
# then Y rows); the matrix is symmetric with ones on the diagonal.
sim_mat_1

tensor([[1.0000, 0.7281, 0.8607, 0.7395, 0.8699, 0.9070],
        [0.7281, 1.0000, 0.6962, 0.6181, 0.7495, 0.8019],
        [0.8607, 0.6962, 1.0000, 0.8281, 0.7893, 0.7802],
        [0.7395, 0.6181, 0.8281, 1.0000, 0.7616, 0.7524],
        [0.8699, 0.7495, 0.7893, 0.7616, 1.0000, 0.9197],
        [0.9070, 0.8019, 0.7802, 0.7524, 0.9197, 1.0000]])

In [5]:
# Exponentiate the temperature-scaled similarities.
exp_mat = (sim_mat_1 / tau).exp()
exp_mat_row_sum = exp_mat.sum(dim=1)
exp_mat_diag = exp_mat.diag()
# Remove each sample's similarity with itself from its row total.
denominator = exp_mat_row_sum - exp_mat_diag

In [6]:
# Entry (i, i) of this product is the similarity between row i of XY_norm
# and row i of YX_norm, i.e. the positive pair of sample i.
sim_mat_2 = torch.matmul(XY_norm, YX_norm.T)
numerator = (sim_mat_2.diag() / tau).exp()

In [7]:
# Every positive pair appears twice (once from the X side, once from the
# Y side), so the second half of the vector repeats the first half.
print(numerator)

tensor([2.0948, 2.1159, 2.1820, 2.0948, 2.1159, 2.1820])


In [8]:
# Per-sample loss: negative log of (positive-pair term / sum over all other samples).
nt_xent = -(numerator / denominator).log()

In [9]:
# One loss term per row of XY_norm: the X samples first, then the Y samples.
print(nt_xent)

tensor([1.6937, 1.5806, 1.6216, 1.6122, 1.6801, 1.6638])


In [10]:
# Final NT-Xent loss: average of the per-sample terms.
loss = nt_xent.mean()

In [11]:
# Scalar loss obtained with the step-by-step linear-algebra computation.
loss

tensor(1.6420)

## Check: Compute $\ell$[0] without linear algebra

In [12]:
def f(x, tau):
    '''
    Input: x is a tensor of similarities, tau is the temperature
    Return: exp(x / tau), computed element-wise
    '''
    scaled = x / tau
    return torch.exp(scaled)
def cos(x, y):
    '''
    Input: x, y are tensors, each flattened to a single row vector
    Return: tensor of shape (1,) holding the cosine similarity of x and y
    '''
    x_row = x.view(1, -1)
    y_row = y.view(1, -1)
    return torch.nn.functional.cosine_similarity(x_row, y_row)

In [13]:
# Numerator of l[0]: exp(cos(x_0, y_0) / tau) for the positive pair.
# Use the notebook-level `tau` (not a literal 1) so this manual check stays
# in sync with the vectorized computation if the temperature is changed.
n = f(cos(X[0], Y[0]), tau)
n

tensor([2.0948])

In [14]:
# Denominator of l[0]: sum of exp(cos(x_0, z) / tau) over every sample z
# other than x_0 itself, i.e. y_0 plus all remaining rows of X and Y.
# The loop bound comes from the batch size and the temperature from `tau`,
# so the check keeps working if either is changed in the setup cell.
d = f(cos(X_norm[0], Y_norm[0]), tau)
for i in range(1, X_norm.shape[0]):
    d = d + f(cos(X_norm[0], X_norm[i]), tau)
    d = d + f(cos(X_norm[0], Y_norm[i]), tau)
d

tensor([11.3943])

In [15]:
# Loss term of sample 0 computed by hand; assert (rather than eyeball) that
# it agrees with the first element of the vectorized result.
ell_0 = -torch.log(n / d)
assert torch.allclose(ell_0, nt_xent[0]), "manual l[0] disagrees with nt_xent[0]"
ell_0

tensor([1.6937])

**Note that this number matches the first element of `nt_xent` (`nt_xent[0]`), which confirms the vectorized computation.**

## NT-Xent loss function

In [16]:
def nt_xent_loss(X, Y, tau=1.0):
    '''
    Compute the NT-Xent (normalized temperature-scaled cross entropy) loss.

    Input:
        X, Y: tensors with the same batch size; row i of X and row i of Y
              form a positive pair. Trailing dimensions are flattened, so
              each input is treated as a (batch, hid_size) matrix.
        tau:  positive temperature that divides the cosine similarities
              (default 1.0).
    Return: scalar tensor, the mean loss over all 2 * batch samples
    Raises: ValueError if the batch sizes differ or tau is not positive
    '''
    if X.shape[0] != Y.shape[0]:
        # Unequal batches would silently pair the wrong rows.
        raise ValueError(
            "X and Y must have the same batch size, got "
            f"{X.shape[0]} and {Y.shape[0]}"
        )
    if tau <= 0:
        raise ValueError(f"tau must be positive, got {tau}")

    # reshape (unlike view) also accepts non-contiguous inputs.
    X = X.reshape(X.shape[0], -1)
    Y = Y.reshape(Y.shape[0], -1)

    # F.normalize clamps the norm at a small eps, so an all-zero row becomes
    # a zero vector instead of NaN.
    X_norm = torch.nn.functional.normalize(X, dim=1)
    Y_norm = torch.nn.functional.normalize(Y, dim=1)
    XY_norm = torch.cat([X_norm, Y_norm], dim=0)

    # Temperature-scaled cosine similarity between every pair of samples.
    logits = (XY_norm @ XY_norm.T) / tau

    # log(sum_{j != i} exp(logits[i, j])): the diagonal (self-similarity) is
    # masked out, and logsumexp avoids the overflow that exp() hits when tau
    # is small.
    self_mask = torch.eye(logits.shape[0], dtype=torch.bool, device=logits.device)
    log_denominator = torch.logsumexp(
        logits.masked_fill(self_mask, float('-inf')), dim=1
    )

    # Positive-pair logits via a row-wise dot product; each pair contributes
    # once from the X side and once from the Y side.
    positive = (X_norm * Y_norm).sum(dim=1) / tau
    log_numerator = torch.cat([positive, positive], dim=0)

    # -log(numerator / denominator) == log_denominator - log_numerator
    return torch.mean(log_denominator - log_numerator)

In [17]:
# Expected to match `loss` from the step-by-step computation above.
nt_xent_loss(X, Y)

tensor(1.6420)