# SimCLR Implementation and Evaluation on CIFAR-10

This notebook implements the SimCLR algorithm, trains it on the CIFAR-10 dataset, and evaluates the learned representations using Linear Probing and K-Nearest Neighbors (KNN) classification.


In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
    

# Importing necessary libraries and modules for the implementation.

### Choose the Contrastive Loss

In [2]:

# Choose the contrastive loss: 'nt_xent' or 'contrastive'
loss_choice = 'nt_xent'
    

### Execution Timers

In [3]:

# Flag to enable or disable timers
enable_timers = True

import time

class Timer:
    def __enter__(self):
        if enable_timers:
            self.start = time.time()
        return self

    def __exit__(self, *args):
        if enable_timers:
            self.end = time.time()
            self.interval = self.end - self.start
            print(f"Elapsed time: {self.interval:.2f} seconds")
    

# Importing necessary libraries and modules for the implementation.

## Data Augmentation

Define the data augmentation pipeline for SimCLR, including random cropping, color jittering, and random flipping.


In [4]:
def simclr_augmentation(img_size):
    color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
    transform = transforms.Compose([
        transforms.RandomResizedCrop(size=img_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ])
    return transform

cifar10_transform = simclr_augmentation(32)


# Define a function for SimCLR data augmentation. This includes random resized cropping, random horizontal flip, color jittering, and grayscale conversion.

## Load CIFAR-10 Dataset

Load the CIFAR-10 training and test datasets.


In [10]:
with Timer():
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar10_transform)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
    

# Load the CIFAR-10 dataset. Training data will undergo the SimCLR augmentation while test data will only be transformed to tensors.

Files already downloaded and verified
Files already downloaded and verified
Elapsed time: 1.05 seconds


## Define SimCLR Encoder and Projection Head

Create the encoder model and projection head using ResNet18 as the base architecture.


In [36]:
with Timer():
    resnet = torchvision.models.resnet18(pretrained=False)
    
    class SimCLREncoder(nn.Module):
        def __init__(self, base_encoder, projection_dim=128):
            super(SimCLREncoder, self).__init__()
            self.resnet = base_encoder
            self.projection_head = nn.Sequential(
                nn.Linear(1000, 512, bias=False),
                nn.ReLU(),
                nn.Linear(512, projection_dim, bias=False)
            )
        
        def forward(self, x):
            x = self.resnet(x)
            x = self.projection_head(x)
            return x
    
    encoder = SimCLREncoder(resnet).to("cuda")
    

# Define the SimCLR encoder which consists of a base encoder (ResNet18 in this case) and a projection head.

Elapsed time: 0.12 seconds


## Define Contrastive Loss

Implement the contrastive loss function used by SimCLR.


In [37]:
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        norm_i = torch.norm(z_i, dim=1).reshape(-1, 1)
        norm_j = torch.norm(z_j, dim=1).reshape(-1, 1)
        z_i = z_i / norm_i
        z_j = z_j / norm_j

        sim_ij = torch.mm(z_i, z_j.T) / self.temperature
        sim_ji = torch.mm(z_j, z_i.T) / self.temperature

        loss_matrix = - torch.log_softmax(sim_ij, dim=1)
        loss = loss_matrix.sum(dim=1).mean()
        return loss


# Define a generic contrastive loss. This computes the similarity between positive pairs and contrasts it with negative pairs.

In [38]:
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5, device="cuda"):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.device = device
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def forward(self, z_i, z_j):
        z = torch.cat([z_i, z_j], dim=0)
        n = z.size(0)
        device = z.device
        # Compute similarity matrix
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        # Exclude the main diagonal from the similarity computation
        mask = torch.eye(n, device=self.device).bool()
        sim = sim.to(device)
        mask = mask.to(device)
        sim.masked_fill_(mask, float('-inf'))
        positive_samples_i = sim[:n // 2, n // 2:].diag().unsqueeze(-1)
        positive_samples_j = sim[n // 2:, :n // 2].diag().unsqueeze(-1)
        positive_samples = torch.cat([positive_samples_i, positive_samples_j], dim=0)
        negatives = sim.masked_select(mask == 0).reshape(n, -1)
        
        # Logits and labels for the loss
        logits = torch.cat((positive_samples, negatives), dim=1)
        logits /= self.temperature
        # Symmetric labels: {0,...,0,1,...,1}
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)
        logits = logits.to(device)
        loss = self.criterion(logits, labels)
        loss /= n

        return loss

# Define the NT-Xent loss used in SimCLR. This loss contrasts the similarity of positive pairs with that of negative pairs and scales it by a temperature parameter.

## Training SimCLR

Train the SimCLR model using the contrastive loss and augmented image pairs from CIFAR-10.


In [45]:
# Create data loaders with SimCLR augmentations
train_loader_i = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=4)
train_loader_j = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=4)

In [None]:
from tqdm import tqdm

with Timer():
    # Initialize optimizer and loss criterion
    learning_rate = 1e-4
    optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    sceduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader_i), eta_min=0, last_epoch=-1)
    if loss_choice == 'nt_xent':
        criterion = NTXentLoss()
    else:
        criterion = ContrastiveLoss()
    
    # Set number of training epochs
    epochs = 50
    
    pbar = tqdm(range(epochs))
    # Training loop
    for epoch in pbar:
        for (x_i, _), (x_j, _) in zip(train_loader_i, train_loader_j):
            x_i = x_i.cuda()
            x_j = x_j.cuda()
            # Pass the inputs through the encoder to get embeddings
            z_i = encoder(x_i)
            z_j = encoder(x_j)
            
            # Compute the loss
            loss = criterion(z_i, z_j)
            
            # Backpropagation and optimization step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # update the progress bar with the loss value
            pbar.set_description(f"Loss: {loss.item():.5f}")
        sceduler.step()
   
# save the trained model
torch.save(encoder.state_dict(), f"simclr_encoder_e{epoch}_lr{learning_rate}.pth")

Loss: 6.93149:  16%|█▌        | 8/50 [01:54<09:06, 13.01s/it]

## Linear Probing

Evaluate the learned representations using Linear Probing. A linear classifier is trained on top of the frozen encoder and its accuracy is reported on the test set.


In [26]:
class LinearProbe(nn.Module):
    def __init__(self, encoder, num_classes=10):
        super(LinearProbe, self).__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)
        return x

linear_probe = LinearProbe(encoder)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(linear_probe.classifier.parameters(), lr=0.001)

# Training the linear probe
for epoch in range(epochs):
    for images, labels in DataLoader(train_dataset, batch_size=256, shuffle=True):
        outputs = linear_probe(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate linear probe
correct, total = 0, 0
with torch.no_grad():
    for images, labels in DataLoader(test_dataset, batch_size=256):
        outputs = linear_probe(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Linear Probe Accuracy: {100 * correct / total:.2f}%")


# Initialize the optimizer (Adam in this case) and set the loss criterion based on the loss_choice flag. Then, train the SimCLR model using the chosen loss.

RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x128 and 512x10)

## K-Nearest Neighbors (KNN) Classification

Evaluate the learned representations using KNN classification. For each sample in the test set, we find the 'K' nearest samples from the training set in the embedding space and assign the label based on a majority vote from these neighbors.


In [None]:
with Timer():
    train_features = []
    train_labels = []
    with torch.no_grad():
        for images, labels in DataLoader(train_dataset, batch_size=256):
            features = encoder(images)
            train_features.append(features.cpu().numpy())
            train_labels.append(labels.cpu().numpy())
    
    train_features = np.concatenate(train_features, axis=0)
    train_labels = np.concatenate(train_labels, axis=0)
    
    knn_classifier = KNeighborsClassifier(n_neighbors=5)
    knn_classifier.fit(train_features, train_labels)
    
    test_features = []
    test_labels = []
    with torch.no_grad():
        for images, labels in DataLoader(test_dataset, batch_size=256):
            features = encoder(images)
            test_features.append(features.cpu().numpy())
            test_labels.append(labels.cpu().numpy())
    
    test_features = np.concatenate(test_features, axis=0)
    test_labels = np.concatenate(test_labels, axis=0)
    
    knn_accuracy = knn_classifier.score(test_features, test_labels)
    print(f"KNN Accuracy: {100 * knn_accuracy:.2f}%")

### GPU-based KNN Implementation


In [None]:
class GPUKNN:
    def __init__(self, k=5):
        self.k = k
        self.train_features = None
        self.train_labels = None
    
    def fit(self, train_features, train_labels):
        self.train_features = self._normalize(train_features)
        self.train_labels = train_labels
    
    def predict(self, test_features):
        test_features = self._normalize(test_features)
        similarity_matrix = torch.mm(test_features, self.train_features.t())
        _, top_k_indices = similarity_matrix.topk(self.k, dim=1, largest=True, sorted=True)
        top_k_labels = torch.gather(self.train_labels.expand(test_features.size(0), -1), 1, top_k_indices)
        predicted_labels = torch.mode(top_k_labels, dim=1).values
        return predicted_labels
    
    def _normalize(self, x):
        return x / x.norm(p=2, dim=1, keepdim=True)
    

# Define a GPU-based K-Nearest Neighbors (KNN) classifier. This classifier uses matrix multiplication to compute similarities between feature vectors, making it efficient on GPU.

In [None]:
with Timer():
    gpu_knn_classifier = GPUKNN(k=5)
    gpu_knn_classifier.fit(torch.from_numpy(train_features).cuda(), torch.from_numpy(train_labels).cuda())
    

# Define a GPU-based K-Nearest Neighbors (KNN) classifier. This classifier uses matrix multiplication to compute similarities between feature vectors, making it efficient on GPU.

In [None]:
with Timer():
    test_features = []
    test_labels = []
    with torch.no_grad():
        for images, labels in DataLoader(test_dataset, batch_size=256):
            features = encoder(images).cuda()
            test_features.append(features.cpu().numpy())
            test_labels.append(labels.cpu().numpy())
            
    test_features = np.concatenate(test_features, axis=0)
    test_labels = np.concatenate(test_labels, axis=0)
    
    gpu_knn_accuracy = (gpu_knn_classifier.predict(torch.from_numpy(test_features).cuda()) == torch.from_numpy(test_labels).cuda()).float().mean().item()
    print(f"GPU KNN Accuracy: {100 * gpu_knn_accuracy:.2f}%")
    

## Results


In [None]:
with Timer():
    print(f"Linear Probe Accuracy: {100 * correct / total:.2f}%")
    print(f"KNN Accuracy: {100 * knn_accuracy:.2f}%")
    print(f"GPU KNN Accuracy: {100 * gpu_knn_accuracy:.2f}%")
    