# Contrastive Loss

---
In this session, we are going to implement the SimCLR loss function (https://arxiv.org/abs/2002.05709).

This follows the InfoNCE loss, i.e., uses two different augmented versions of the same image as positive pair and the other images in the batch as negative samples, and the batch construction of the N-pair-mc loss.


In [111]:
import os
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision.io import read_image

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

from PIL import Image

In [112]:
class CustomImageDatasetBis(Dataset):
    def __init__(self, data, targets=None, transform=None, target_transform=None):
        self.imgs = data # Tensore di tutte le immagini
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = self.imgs[idx] # Sampling randomico di emlementi del dataset
        if isinstance(img, str): # Può capitare che il dataset sia salvato come stringhe/path (da usare quando non è possibile salvarsi tutto il tensore del dataset)
          img = read_image(img_path) # Fuzione di Torchvision, trova un'immaigne dal path fornito
        else:
          img = Image.fromarray(img.astype('uint8'), 'RGB')
        label = self.targets[idx] # Non utile nel caso di self-supervised ovviamente
        if self.transform:
            img1 = self.transform(img) # Utilizzo le trasformazioni
            img2 = self.transform(img)
            # img = self.transform(img)  Già così genero due immagini augmented diverse, siccome le funzioni che trasformano sono randomiche (TODO, rivedi le variabili)
        if self.target_transform:
            label1 = self.target_transform(label)
            label2 = self.target_transform(label)
        else:
            label1 = label
            label2 = label
        return img1, img2# , label1, label2 # Concateno immaigni e labels
    
# simclr DA pipeline
s=1
size=32
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
transform = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.RandomApply([color_jitter], p=0.8),
                                  transforms.RandomGrayscale(p=0.2),
                                  transforms.GaussianBlur(kernel_size=3),
                                  transforms.ToTensor()])
    

In [113]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Backbone(nn.Module):  # emulates a smaller resnet18
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))  # Outputs 256-dim vector

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)  # Flatten to (batch_size, 256)
        return x


class SiameseNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = Backbone()
        self.projection = nn.Sequential(
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10)
        )

    def forward(self, x1, x2):
        x1 = self.backbone(x1)

        x2 = self.backbone(x2)

        x1 = F.normalize(x1, dim=1)
        x2 = F.normalize(x2, dim=1)
    
        return torch.cat((x1,x2), dim=0)

# Check output
a = SiameseNet()
input1 = torch.randn(5, 3, 32, 32)
input2 = torch.randn(5, 3, 32, 32)
output = a(input1, input2)

print("Output shapes:", output.shape)




Output shapes: torch.Size([10, 256])


In [None]:
import time
# non abbiamo mai fattoo operazioni in place- io genero tutto su var di appoggio e poi il risultato lo metto su una nuova var
#vai operazioni in place per fare boost di performance


#bath size + grande possibile ok ma devo fare attensione a ottimizzare codice, arrivi agile a 4096x4096 e occupi moltissima memoria
# evita di istaziare due matrici


#poi le matrici delle maschere è molto sparso


#1. cosa fa loss
#2 indica dove sono positive negative ecc, aiutati con grafici x far capire cosa succede nel codice
# no computazione o training
import numpy as np
class ContrastiveLossFor(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, features):
        features = F.normalize(features, dim=1)
        # normalize features to later compute cosine distance/similarity btw them

        # compute the similarity matrix btw features
        # (consider that feature are normalized! so the cosine similarity is ...)
        #! similitudine tra features si fa con dot product

        similarity_matrix = torch.matmul(features, features.T)
        # print(similarity_matrix.shape) # expected 128*128
        start = time.time()
        
        batch_size=features.shape[0] // 2#easy version
        logits = torch.zeros(2*batch_size, 2*batch_size-1) # expected 128*127 -> 2N * 2N-1
        for idx, val in enumerate(similarity_matrix):
            # print( 'val shape:', val.shape)
            row = torch.zeros(2*batch_size-1) # 127
            pos_idx = idx + batch_size if idx < batch_size else idx-batch_size
            row[0] = val[pos_idx]
            row[1:]=torch.tensor([v for i,v in enumerate(val) if i!= idx and i!=pos_idx])
            logits[idx]=row

        logits = logits/self.temperature
        gt = torch.zeros(logits.shape[0],dtype=torch.long)
        loss = self.criterion(logits,gt)
        end = time.time()
        print(' Loss execution time: ', end - start)


        ## TODO

        # create the logits tensor where:
        #   - in the first position there is the similarity of the positive pair
        #   - in the other 2N-1 positions there are the similarity w negatives
        # the shape of the tensor need to be 2Nx2N-1, with N is the batch size
        # in diagonale ho elemento*elemento, va rimossa
        #  devi identificare i positicw da mandare a ground truth
        # , metti similitudine coppia positive e similitudine positive negative

        # to compute the contrastive loss using the CE loss, we just need to
        # specify where is the similarity of the positive pair in the logits tensor
        # since we put in the first position we create a gt of all zeros
        # N.B.: this is just one of the possible implementations!
        return loss
    
class ContrastiveLossVect(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = torch.nn.CrossEntropyLoss()
    
    def forward(self, features):
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features,features.T)

        start = time.time()
        labels = torch.cat([torch.arange(features.shape[0]//2) for i in range(2)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        mask = torch.eye(labels.shape[0], dtype=torch.bool)
        labels = labels[~mask].view(labels.shape[0],-1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0],-1)
        positives = similarity_matrix[labels.bool()].view(labels.shape[0],-1)
        negatives = similarity_matrix[~labels.bool()].view(labels.shape[0],-1)
        logits = torch.cat([positives,negatives],dim=1)
        logits = logits / self.temperature

        gt = torch.zeros(logits.shape[0],dtype=torch.long,device=logits.device)
        end = time.time()
        print('Loss vect execution time:', end - start)
        return self.criterion(logits,gt)    
    
class ContrastiveLossOpt(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = torch.nn.CrossEntropyLoss()
    
    def forward(self, features):
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features,features.T)

        start = time.time()

        # TODO optimize me with in place operations
        labels = torch.cat([torch.arange(features.shape[0]//2) for i in range(2)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        labels = labels[~torch.eye(labels.shape[0], dtype=torch.bool)].view(labels.shape[0],-1)

        similarity_matrix = similarity_matrix[~torch.eye(labels.shape[0], dtype=torch.bool)].view(similarity_matrix.shape[0],-1)

        logits = torch.cat([
            similarity_matrix[labels.bool()].view(labels.shape[0],-1), # positive
            similarity_matrix[~labels.bool()].view(labels.shape[0],-1) # negative
                            ],
                            dim=1)
        logits = logits / self.temperature

        gt = torch.zeros(logits.shape[0],dtype=torch.long, device=logits.device)

        end = time.time()
        print('Loss opt execution time:', end - start)
        return self.criterion(logits,gt)    
    



In [115]:
data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)   
trainset = CustomImageDatasetBis(data.data, data.targets, transform=transform)
dataloader = DataLoader(trainset, batch_size=64, shuffle=True)
x1,x2 = next(iter(dataloader))
model = SiameseNet().to('cuda')
x1,x2 = x1.to('cuda'), x2.to('cuda')
features = model(x1, x2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
        # normalize features to later compute cosine distance/similarity btw them
features = F.normalize(features, dim=1)
        # compute the similarity matrix btw features
        # (consider that feature are normalized! so the cosine similarity is ...)
        #! similitudine tra features si fa con dot product

similarity_matrix = torch.matmul(features, features.T)
mask = ~torch.eye(similarity_matrix.size(0), dtype=torch.bool, device=similarity_matrix.device)
off_diagonal_elements = similarity_matrix[mask]
print(off_diagonal_elements.shape)




Files already downloaded and verified
torch.Size([16256])


Let's now use the Dataset which creates the two augmented views for each image and the Siamese Network from the past lab session [1](https://colab.research.google.com/drive/1NJwAFbRiD4MdwWf__6P2Lm0xYk_DNdVu?usp=sharing) and [2](https://colab.research.google.com/drive/1AMkh0q8L5nJScx7v6cMWoK336zqOqDY6?usp=sharing) and create a training loop

In [None]:
data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)   
trainset = CustomImageDatasetBis(data.data, data.targets, transform=transform)
dataloader = DataLoader(trainset, batch_size=64, shuffle=True, pin_memory=True)# fast computation

model = SiameseNet().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = ContrastiveLossOpt()

for idx, (img1,img2) in enumerate(dataloader):
    img1  = img1.to('cuda')
    img2 = img2.to('cuda')

    optimizer.zero_grad()
    features = model(img1, img2)
    loss = criterion(features) #criterion(x1,x2, target)
    loss.backward()
    optimizer.step()

    if idx == 3:
        break

Files already downloaded and verified
Loss vect execution time: 0.012261629104614258
Loss vect execution time: 0.007040500640869141
Loss vect execution time: 0.0055620670318603516
Loss vect execution time: 0.007747650146484375


0.0009448528289794922

Files already downloaded and verified


Loss vect execution time: 0.00289154052734375

Loss vect execution time: 0.0009870529174804688

Loss vect execution time: 0.0019037723541259766

Loss vect execution time: 0.0010981559753417969
