In [2]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models

from sklearn.manifold import TSNE


In [20]:
def contrastive_loss(e1, e2, tau=1.0):
    """Loop-based contrastive loss over two batches of paired embeddings.

    Row ``i`` of ``e1`` and row ``i`` of ``e2`` are treated as a positive
    pair; every other row of the concatenated batch acts as a negative.

    Args:
        e1: 2-D tensor of embeddings, one row per sample.
        e2: 2-D tensor with the same shape as ``e1``.
        tau: Temperature dividing every similarity before exponentiation.

    Returns:
        Scalar tensor: the pairwise loss averaged over all ``2 * N`` anchors,
        where ``N`` is the number of rows in ``e1``.
    """
    # Unit-length rows make the dot products below cosine similarities.
    e1 = F.normalize(e1, dim=1)
    e2 = F.normalize(e2, dim=1)

    # Rows 0..N-1 come from e1, rows N..2N-1 from e2.
    e12 = torch.cat((e1, e2), dim=0)

    similarity_matrix = torch.mm(e12, e12.T)

    N = e1.shape[0]  # batch size
    loss = 0.0

    for i in range(N):
        # Each positive pair contributes twice, once per anchor direction.
        # The temperature must be forwarded, otherwise l_ij falls back to
        # its own default and the caller's tau is silently ignored.
        loss += l_ij(i, i + N, similarity_matrix, tau) + l_ij(
            i + N, i, similarity_matrix, tau)

    return loss / (2 * N)


def l_ij(i, j, similarity_matrix, tau=1.0):
    """Loss for anchor row ``i`` with positive column ``j``.

    Args:
        i: Row index of the anchor in ``similarity_matrix``.
        j: Column index of the anchor's positive sample.
        similarity_matrix: Square tensor of pairwise similarities.
        tau: Temperature dividing every similarity before exponentiation.

    Returns:
        Scalar tensor ``-log(exp(s_ij / tau) / sum_{k != i} exp(s_ik / tau))``.
    """
    sim_ij = similarity_matrix[i, j]

    numerator = torch.exp(sim_ij / tau)
    # Sum over the whole row, then remove the anchor's self-similarity term.
    denominator = torch.exp(similarity_matrix[i] / tau).sum() - torch.exp(
        similarity_matrix[i, i] / tau)

    return -1 * torch.log(numerator / denominator)


In [87]:
%%time
# Small hand-written batch: row k of I is paired with row k of J.
I = torch.tensor([[1.0, 2.0], [3.0, -2.0], [1.0, 5.0]])
# Larger random input, presumably kept for timing experiments (uncomment to use).
# I = torch.rand(10000, 100)
J = torch.tensor([[1.0, 0.75], [2.8, -1.75], [1.0, 4.7]])
# J = torch.rand(10000, 100)
loss = contrastive_loss(I, J)
# Bare last expression so the notebook displays the resulting value.
loss

CPU times: user 1.81 ms, sys: 5.31 ms, total: 7.12 ms
Wall time: 7.46 ms


tensor(1.1327)

In [88]:
%%time
# Vectorized contrastive loss


def contrastive_loss_vectorized(e1, e2, tau=1.0):
    """Vectorized contrastive loss over two batches of paired embeddings.

    Row ``i`` of ``e1`` and row ``i`` of ``e2`` form a positive pair; every
    other row of the concatenated batch acts as a negative.

    Args:
        e1: 2-D tensor of embeddings, one row per sample.
        e2: 2-D tensor with the same shape as ``e1``.
        tau: Temperature dividing every similarity before exponentiation.

    Returns:
        Scalar tensor: the mean loss over all ``2 * N`` anchors.
    """
    # Unit-length rows make the dot products below cosine similarities.
    e1 = F.normalize(e1, dim=1)
    e2 = F.normalize(e2, dim=1)

    e12 = torch.cat((e1, e2), dim=0)
    n_sample = e12.shape[0]  # 2N samples

    similarity_matrix = torch.exp(torch.mm(e12, e12.T) / tau)

    # Denominator: each row's sum with the self-similarity entry zeroed out.
    # The mask is created on the embeddings' device so GPU inputs do not hit
    # a CPU/GPU mismatch.
    self_mask = torch.eye(n_sample, device=e12.device).bool()
    negative_sim = similarity_matrix.masked_fill(self_mask, 0.0).sum(dim=-1)

    # Numerator: similarity of each positive pair, repeated so that both
    # members of the pair serve as the anchor once.
    positive_sim = torch.exp(torch.sum(e1 * e2, dim=-1) / tau)
    positive_sim = torch.cat((positive_sim, positive_sim), dim=0)

    return -torch.log(positive_sim / negative_sim).mean()


# The function already returns the batch-mean scalar, so no extra reduction is needed.
contrastive_loss_vectorized(I, J)


CPU times: user 1.54 ms, sys: 1.49 ms, total: 3.03 ms
Wall time: 1.89 ms


tensor(1.1327)

## PyTorch Lightning Implementation

In [1]:
# import cv2
import numpy as np
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import torchvision.transforms as transforms
from torch.optim import Adam

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor

from pl_bolts.models.self_supervised.resnets import resnet50_bn
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
from pl_bolts.metrics import mean, accuracy

from pl_bolts.models.self_supervised.evaluator import Flatten
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization, stl10_normalization
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.optimizers import LARSWrapper

ImportError: cannot import name 'resnet50_bn' from 'pl_bolts.models.self_supervised.resnets' (/Users/nishant/opt/miniconda3/envs/ptlite/lib/python3.9/site-packages/pl_bolts/models/self_supervised/resnets.py)

In [1]:
def add_two_numbers(x, y):
    """Return the result of ``x + y``."""
    total = x + y
    return total

