# SimCLR Implementation and Evaluation on CIFAR-10

This notebook implements the SimCLR algorithm, trains it on the CIFAR-10 dataset, and evaluates the learned representations using Linear Probing and K-Nearest Neighbors (KNN) classification.


In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import torch.nn.functional as F
import torchvision.models as models

# Importing necessary libraries and modules for the implementation.

### Execution Timers

In [2]:

# Flag to enable or disable timers
enable_timers = True

import time

class Timer:
    def __enter__(self):
        if enable_timers:
            self.start = time.time()
        return self

    def __exit__(self, *args):
        if enable_timers:
            self.end = time.time()
            self.interval = self.end - self.start
            print(f"Elapsed time: {self.interval:.2f} seconds")
    

# Importing necessary libraries and modules for the implementation.

## Load CIFAR-10 Dataset

Load the CIFAR-10 training and test datasets.


In [3]:
from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset

dataset = ContrastiveLearningDataset(root_folder='data')
train_dataset = dataset.get_dataset('cifar10', 2)
train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=512, shuffle=True,
        num_workers=8, pin_memory=True, drop_last=True)


Files already downloaded and verified


## Define SimCLR Encoder and Projection Head

Create the encoder model and projection head using ResNet18 as the base architecture.


In [4]:
class ResNetSimCLR(nn.Module):

    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
                            "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}

        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

    def _get_basemodel(self, model_name):
        try:
            model = self.resnet_dict[model_name]
        except:
            raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")

        return model

    def forward(self, x):
        return self.backbone(x)

## Define Contrastive Loss

Implement the contrastive loss function used by SimCLR.


In [5]:
def info_nce_loss(features, temperature=0.5):
        batch_size = features.shape[0] // 2 # 2 views per batch
        
        labels = torch.cat([torch.arange(batch_size) for i in range(2)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
        labels = labels.to(device)

        features = F.normalize(features, dim=1)

        similarity_matrix = torch.matmul(features, features.T)
        # assert similarity_matrix.shape == (
        #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
        # assert similarity_matrix.shape == labels.shape

        # discard the main diagonal from both: labels and similarities matrix
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
        # assert similarity_matrix.shape == labels.shape

        # select and combine multiple positives
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

        # select only the negatives
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)

        logits = logits / temperature
        return logits, labels

## Training SimCLR

Train the SimCLR model using the contrastive loss and augmented image pairs from CIFAR-10.


In [None]:
from torch.utils.tensorboard import SummaryWriter
import os
from tqdm import tqdm
import logging
from utils import accuracy, save_checkpoint


with Timer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training with gpu: {device}.")
    # Initialize optimizer and loss criterion
    model = ResNetSimCLR(base_model='resnet18', out_dim=128)
    model = model.to(device)
    lr = 3e-4
    weight_decay = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                           last_epoch=-1)
    writer = SummaryWriter()
    logging.basicConfig(filename=os.path.join(writer.log_dir, 'training.log'), level=logging.DEBUG)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    # Set number of training epochs
    epochs = 800
    log_every_n_epochs = 1
    logging.info(f"Start SimCLR training for {epochs} epochs.")
    logging.info(f"Training with gpu: {device}.")
    best_acc = 0
    for epoch_counter in range(epochs):
        loss_epoch = 0
        for images, _ in tqdm(train_loader):
            images = torch.cat(images, dim=0)

            images = images.to(device)

            # with autocast(enabled=fp16_precision):
            features = model(images)
            logits, labels = info_nce_loss(features)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_epoch += loss.item()
            # scaler.scale(loss).backward()
            # scaler.step(self.optimizer)
            # scaler.update()
        avg_loss = loss_epoch / len(train_loader)
        # print(f"Epoch {epoch_counter}:\tLoss: {avg_loss}")
        # every log_every_n_epochs log epoch loss and accuracy
        if epoch_counter % log_every_n_epochs == 0:
            top1, top5 = accuracy(logits, labels, topk=(1, 5))
            writer.add_scalar('loss', avg_loss, global_step=epoch_counter)
            writer.add_scalar('acc/top1', top1[0], global_step=epoch_counter)
            writer.add_scalar('acc/top5', top5[0], global_step=epoch_counter)
            writer.add_scalar('learning_rate', scheduler.get_last_lr()[0], global_step=epoch_counter)
            if top1[0] > best_acc:
                best_acc = top1[0]
                save_checkpoint({
                    'epoch': epoch_counter,
                    'arch': 'resnet18',
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, is_best=True, filename=os.path.join(writer.log_dir, f'checkpoint_best.pth.tar'))


        # warmup for the first 10 epochs
        if epoch_counter >= 10:
            scheduler.step()
        logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")

    logging.info("Training has finished.")
    # save model checkpoints
    checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(epochs)
    save_checkpoint({
        'epoch': epochs,
        'arch': 'resnet18',
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, is_best=False, filename=os.path.join(writer.log_dir, checkpoint_name))
    logging.info(f"Model checkpoint and metadata has been saved at {writer.log_dir}.")

Training with gpu: cuda.


100%|██████████| 97/97 [00:05<00:00, 16.48it/s]
100%|██████████| 97/97 [00:06<00:00, 15.77it/s]
100%|██████████| 97/97 [00:06<00:00, 15.95it/s]
100%|██████████| 97/97 [00:05<00:00, 16.44it/s]
100%|██████████| 97/97 [00:05<00:00, 16.46it/s]
100%|██████████| 97/97 [00:05<00:00, 16.50it/s]
100%|██████████| 97/97 [00:05<00:00, 16.40it/s]
100%|██████████| 97/97 [00:05<00:00, 16.25it/s]
100%|██████████| 97/97 [00:05<00:00, 16.49it/s]
100%|██████████| 97/97 [00:06<00:00, 15.82it/s]
100%|██████████| 97/97 [00:05<00:00, 16.52it/s]
100%|██████████| 97/97 [00:06<00:00, 15.56it/s]
100%|██████████| 97/97 [00:05<00:00, 16.83it/s]
100%|██████████| 97/97 [00:05<00:00, 16.59it/s]
100%|██████████| 97/97 [00:05<00:00, 16.75it/s]
100%|██████████| 97/97 [00:05<00:00, 16.45it/s]
100%|██████████| 97/97 [00:06<00:00, 15.80it/s]
100%|██████████| 97/97 [00:05<00:00, 16.67it/s]
100%|██████████| 97/97 [00:06<00:00, 15.84it/s]
100%|██████████| 97/97 [00:05<00:00, 16.36it/s]
100%|██████████| 97/97 [00:06<00:00, 15.

Load the model checkpoint and evaluate the learned representations using Linear Probing and KNN classification.

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(pretrained=False, num_classes=10).to(device)
# Load the checkpoint
checkpoint_path = 'runs/Sep26_17-15-26_cpsadmin-Z790-AORUS-ELITE-AX/checkpoint_best.pth.tar'
checkpoint = torch.load(checkpoint_path)
state_dict = checkpoint['state_dict']
# model.load_state_dict(state_dict)

for k in list(state_dict.keys()):
  if k.startswith('backbone.'):
    if k.startswith('backbone') and not k.startswith('backbone.fc'):
      # remove prefix
      state_dict[k[len("backbone."):]] = state_dict[k]
  del state_dict[k]
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']

In [10]:
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

In [11]:
from torchvision import datasets
def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.CIFAR10('./data', train=True, download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)
  
  test_dataset = datasets.CIFAR10('./data', train=False, download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=shuffle)
  return train_loader, test_loader


In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)
train_loader, test_loader = get_cifar10_data_loaders(download=True)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
from utils import accuracy
epochs = 100
for epoch in range(epochs):
  top1_train_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(train_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
    loss = criterion(logits, y_batch)
    top1 = accuracy(logits, y_batch, topk=(1,))
    top1_train_accuracy += top1[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  top1_train_accuracy /= (counter + 1)
  top1_accuracy = 0
  top5_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(test_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
  
    top1, top5 = accuracy(logits, y_batch, topk=(1,5))
    top1_accuracy += top1[0]
    top5_accuracy += top5[0]
  
  top1_accuracy /= (counter + 1)
  top5_accuracy /= (counter + 1)
  print(f"Epoch {epoch}:\tTrain Accuracy: {top1_train_accuracy.item():.2f}\tTest Accuracy: {top1_accuracy.item():.2f}\tTest Top-5 Accuracy: {top5_accuracy.item():.2f}")
  
  

Epoch 0:	Train Accuracy: 60.43	Test Accuracy: 69.10	Test Top-5 Accuracy: 96.72
Epoch 1:	Train Accuracy: 70.81	Test Accuracy: 70.09	Test Top-5 Accuracy: 97.36
Epoch 2:	Train Accuracy: 71.92	Test Accuracy: 70.94	Test Top-5 Accuracy: 97.63
Epoch 3:	Train Accuracy: 72.65	Test Accuracy: 71.42	Test Top-5 Accuracy: 97.78
Epoch 4:	Train Accuracy: 73.22	Test Accuracy: 72.00	Test Top-5 Accuracy: 97.90
Epoch 5:	Train Accuracy: 73.57	Test Accuracy: 72.50	Test Top-5 Accuracy: 97.96
Epoch 6:	Train Accuracy: 73.92	Test Accuracy: 72.64	Test Top-5 Accuracy: 97.96
Epoch 7:	Train Accuracy: 74.20	Test Accuracy: 72.93	Test Top-5 Accuracy: 98.06
Epoch 8:	Train Accuracy: 74.46	Test Accuracy: 73.11	Test Top-5 Accuracy: 98.12
Epoch 9:	Train Accuracy: 74.72	Test Accuracy: 73.39	Test Top-5 Accuracy: 98.15
Epoch 10:	Train Accuracy: 74.91	Test Accuracy: 73.48	Test Top-5 Accuracy: 98.23
Epoch 11:	Train Accuracy: 75.06	Test Accuracy: 73.69	Test Top-5 Accuracy: 98.24
Epoch 12:	Train Accuracy: 75.21	Test Accuracy: 73.