In [54]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import SubsetRandomSampler
import numpy as np
import random

import numpy as np
import pandas as pd
import os
import glob
import tqdm
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import sklearn

# Machine epsilon of the Python float type, kept as a small constant for numerical guards.
eps = np.finfo(float).eps

# Default matplotlib figure size (width, height).
plt.rcParams['figure.figsize'] = 10, 10
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Load the autoreload extension and use mode 2 so edited modules are re-imported automatically.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
# Synthetic data for sanity-checking the contrastive loss: x_j is x_i plus small
# Gaussian noise, so row k of x_i and row k of x_j form a "positive pair".
# Seed the global RNG so re-running the notebook yields the same tensors.
torch.manual_seed(0)

batch_size = 4
feat_dim = 512
tau = 1
x_i = torch.randn(batch_size, feat_dim)
x_j = x_i + 0.05 * torch.randn(batch_size, feat_dim)
# Stack both views along the batch axis: rows [0, batch_size) are x_i,
# rows [batch_size, 2 * batch_size) are x_j.
x = torch.cat((x_i, x_j), dim=0)

In [56]:
# Pairwise cosine similarity between all rows of x, exponentiated with temperature tau.
row_norms = torch.norm(x, dim=1).unsqueeze(1)

# Numerator: raw dot products; denominator: outer product of the row norms.
sim_mat_nom = x @ x.T
sim_mat_denom = row_norms @ row_norms.T

# Clamp the denominator so an all-zero row cannot trigger a division by zero.
sim_mat = torch.exp((sim_mat_nom / sim_mat_denom.clamp(min=1e-16)) / tau)

print(sim_mat.size())

torch.Size([8, 8])


In [57]:
# Show the full matrix of exp(cosine similarity / tau) values built in the previous cell.
print(sim_mat)

tensor([[2.7183, 1.0419, 0.9345, 0.9992, 2.7151, 1.0429, 0.9367, 0.9996],
        [1.0419, 2.7183, 0.9406, 0.9986, 1.0444, 2.7150, 0.9391, 0.9958],
        [0.9345, 0.9406, 2.7183, 1.0348, 0.9302, 0.9403, 2.7151, 1.0350],
        [0.9992, 0.9986, 1.0348, 2.7183, 1.0002, 0.9953, 1.0388, 2.7145],
        [2.7151, 1.0444, 0.9302, 1.0002, 2.7183, 1.0454, 0.9324, 1.0004],
        [1.0429, 2.7150, 0.9403, 0.9953, 1.0454, 2.7183, 0.9388, 0.9923],
        [0.9367, 0.9391, 2.7151, 1.0388, 0.9324, 0.9388, 2.7183, 1.0392],
        [0.9996, 0.9958, 1.0350, 2.7145, 1.0004, 0.9923, 1.0392, 2.7183]])


In [58]:
# Row indices of the first view and of the matching rows in the second view.
i_ind = torch.arange(0, batch_size)
j_ind = torch.arange(batch_size, batch_size * 2)

# Zero the self-similarity (diagonal) entries. The out-of-place masked_fill
# leaves the tensor produced by the earlier cell untouched, so this cell can be
# re-run safely.
diag_ind = torch.eye(batch_size * 2).bool()
sim_mat = sim_mat.masked_fill(diag_ind, 0)

# Not used in this cell; retained so any code that expects these names still runs.
left_score = []
right_score = []
losses = []
# Pair row i with row i + batch_size (derived from batch_size instead of a
# hard-coded 4, so the check follows the configured batch size).
for i, j in zip(i_ind.tolist(), j_ind.tolist()):
    print('{}-{}'.format(i, j))
    # Share of row i's total similarity mass that falls on its positive pair (x2).
    loss = 2 * sim_mat[i][j] / torch.sum(sim_mat[i, :])
    losses.append(loss)

# Stack the 0-d tensors rather than rebuilding a tensor from a Python list.
torch.stack(losses).mean()

0-4
1-5
2-6
3-7


tensor(0.6268)

In [84]:
class contrastive_loss(nn.Module):
    """Contrastive (NT-Xent style) loss over a stacked batch of paired embeddings.

    The input stacks two views of the same ``b`` samples along dim 0, so row
    ``i`` and row ``i + b`` form a positive pair. For every row ``r`` the
    per-row loss is::

        -log( exp(s[r, pos(r)] / tau) / sum_{k != r} exp(s[r, k] / tau) )

    where ``s`` is the pairwise cosine similarity. The module returns the mean
    over all ``2 * b`` rows.

    Args:
        tau: Temperature that divides the cosine similarities.
    """

    def __init__(self, tau=1):
        super(contrastive_loss, self).__init__()
        self.tau = tau

    def forward(self, x):
        """Compute the loss for a stacked batch.

        Args:
            x: 2-D tensor of shape ``(2 * b, feat_dim)``; the first ``b`` rows
                are one view and the last ``b`` rows the matching second view.

        Returns:
            0-d tensor holding the mean loss.

        Raises:
            ValueError: If ``x`` is not 2-D or has an odd number of rows.
        """
        if x.dim() != 2 or x.size(0) % 2 != 0:
            raise ValueError(
                'expected a 2-D tensor with an even number of rows, got shape {}'.format(
                    tuple(x.shape)))
        b_sz = x.size(0) // 2

        # Cosine similarity of every row pair; the clamp guards against zero-norm rows.
        norms = torch.norm(x, dim=1, keepdim=True)
        cos_sim = torch.mm(x, x.T) / torch.mm(norms, norms.T).clamp(min=1e-16)
        logits = cos_sim / self.tau

        # Positive-pair logits sit on the b_sz-th off-diagonal. The matrix is
        # symmetric, so the same values serve both halves of the batch.
        pos = torch.diagonal(logits, offset=b_sz)
        pos = torch.cat((pos, pos), dim=0)

        # Exclude self-similarity from the denominator. The mask is built on the
        # input's device, and the fill is out-of-place so autograd's saved
        # tensors are never overwritten (an in-place fill on the exp() output
        # makes backward() fail).
        self_mask = torch.eye(2 * b_sz, dtype=torch.bool, device=x.device)
        masked_logits = logits.masked_fill(self_mask, float('-inf'))

        # -log(exp(pos) / sum(exp(logits))) == logsumexp(logits) - pos; working in
        # log space avoids overflow of exp(1 / tau) for small temperatures.
        loss = torch.mean(torch.logsumexp(masked_logits, dim=1) - pos)

        return loss

In [85]:
class contrastive_loss2(nn.Module):
    """Contrastive (NT-Xent style) loss taking the two views as separate tensors.

    Row ``k`` of ``xi`` and row ``k`` of ``xj`` form a positive pair. The two
    inputs are stacked along dim 0 and, for every row ``r`` of the stack, the
    per-row loss is::

        -log( exp(s_pos(r) / tau) / sum_{k != r} exp(s[r, k] / tau) )

    where ``s`` is the cosine similarity. The module returns the mean over all
    ``2 * n`` rows.

    Args:
        tau: Temperature that divides the cosine similarities.
    """

    def __init__(self, tau=1):
        super(contrastive_loss2, self).__init__()
        self.tau = tau

    def forward(self, xi, xj):
        """Compute the loss for two aligned batches of embeddings.

        Args:
            xi: 2-D tensor of shape ``(n, feat_dim)`` holding the first view.
            xj: Tensor of the same shape as ``xi`` holding the second view.

        Returns:
            0-d tensor holding the mean loss.

        Raises:
            ValueError: If ``xi`` is not 2-D or the two shapes differ.
        """
        if xi.dim() != 2 or xi.shape != xj.shape:
            raise ValueError(
                'expected two 2-D tensors of identical shape, got {} and {}'.format(
                    tuple(xi.shape), tuple(xj.shape)))
        n = xi.size(0)

        x = torch.cat((xi, xj), dim=0)

        # Cosine similarity of every row pair; the clamp guards against zero-norm rows.
        norms = torch.norm(x, dim=1, keepdim=True)
        logits = torch.mm(x, x.T) / torch.mm(norms, norms.T).clamp(min=1e-16) / self.tau

        # Exclude self-similarity from the denominator. The mask is built on the
        # input's device, and the fill is out-of-place so autograd's saved
        # tensors are never overwritten (an in-place fill on the exp() output
        # makes backward() fail).
        self_mask = torch.eye(2 * n, dtype=torch.bool, device=x.device)
        log_denom = torch.logsumexp(logits.masked_fill(self_mask, float('-inf')), dim=1)

        # Positive-pair logits computed directly from the aligned rows. The
        # denominator is clamped like the matrix one so a zero vector gives
        # similarity 0 instead of 0/0 = nan.
        pair_norm = (torch.norm(xi, dim=1) * torch.norm(xj, dim=1)).clamp(min=1e-16)
        pos = torch.sum(xi * xj, dim=-1) / pair_norm / self.tau
        pos = torch.cat((pos, pos), dim=0)

        # -log(exp(pos) / sum(exp(logits))) == logsumexp(logits) - pos; working in
        # log space avoids overflow of exp(1 / tau) for small temperatures.
        loss = torch.mean(log_denom - pos)

        return loss
        

In [86]:
# Compare the two loss implementations on the same synthetic positive pairs.
# Seed the global RNG so re-running the notebook yields the same tensors.
torch.manual_seed(0)

batch_size = 4
feat_dim = 512
tau = 1
xi = torch.randn(batch_size, feat_dim)
xj = xi + (0.05 * torch.randn(batch_size, feat_dim))

# The single-input loss expects both views stacked along the batch axis.
x = torch.cat((xi, xj), dim=0)

loss_func1 = contrastive_loss()
loss_func2 = contrastive_loss2()

loss1 = loss_func1(x)
loss2 = loss_func2(xi, xj)

print(loss1)

print(loss2)

# Both formulations describe the same quantity; fail loudly if they ever diverge.
assert torch.allclose(loss1, loss2, atol=1e-5), 'contrastive_loss and contrastive_loss2 disagree'

tensor(1.1710)
tensor(1.1710)
