| Credentials |                                  |
|----|----------------------------------|
|Host | Montanuniversitaet Leoben        |
|Web | https://cps.unileoben.ac.at      |
|Mail | cps@unileoben.ac.at              |
|Author | Fotios Lygerakis                 |
|Corresponding Authors | fotios.lygerakis@unileoben.ac.at |
|Last edited | 28.09.2023                       |

# SimCLR Implementation and Evaluation on CIFAR-10

This notebook implements the SimCLR algorithm, trains it on the CIFAR-10 dataset, and evaluates the learned representations using Linear Probing.


In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import torch.nn.functional as F
import torchvision.models as models

# Importing necessary libraries and modules for the implementation.

### Execution Timers

In [2]:

# Flag to enable or disable timers
enable_timers = True

import time

class Timer:
    """Context manager that prints how long the enclosed block took.

    Timing only happens when the module-level ``enable_timers`` flag is truthy
    at the moment the block is entered. When active, ``start``, ``end`` and
    ``interval`` (seconds) are stored on the instance and the interval is
    printed on exit. When inactive, nothing is recorded or printed.
    """

    def __enter__(self):
        # Latch the flag once: if it were re-read in __exit__, flipping it
        # inside the block would make __exit__ use a `start` that was never set.
        self._active = bool(enable_timers)
        if self._active:
            # perf_counter is monotonic, so the interval cannot go negative
            # when the system clock is adjusted while the block runs.
            self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        if getattr(self, "_active", False):
            self.end = time.perf_counter()
            self.interval = self.end - self.start
            print(f"Elapsed time: {self.interval:.2f} seconds")
    

# A small `Timer` context manager that prints the elapsed time of a code block when `enable_timers` is set.

## Load CIFAR-10 Dataset

Load the CIFAR-10 training and test datasets.


In [3]:
from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset

dataset = ContrastiveLearningDataset(root_folder='data')

# Fetch the three dataset variants first, then wrap each one in a loader.
train_dataset = dataset.get_dataset('cifar10_train', 2, type='moco')
memory_dataset = dataset.get_dataset('cifar10_memory', 2, type='moco')
test_dataset = dataset.get_dataset('cifar10_test', 2, type='moco')

# Training loader: shuffled, incomplete last batch dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=512,
    shuffle=True,
    drop_last=True,
    num_workers=8,
    pin_memory=True,
)

# Memory and test loaders: fixed order, every sample kept.
memory_loader = DataLoader(
    memory_dataset,
    batch_size=512,
    shuffle=False,
    drop_last=False,
    num_workers=8,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=512,
    shuffle=False,
    drop_last=False,
    num_workers=8,
    pin_memory=True,
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


## Define SimCLR Encoder and Projection Head

Create the encoder model and projection head using ResNet18 as the base architecture.


In [4]:
from torchvision.models import resnet


class ModelBase(nn.Module):
    """
    ResNet backbone re-assembled for CIFAR-sized inputs.

    Compared with the stock torchvision ResNet it:
    (i) swaps conv1 for a 3x3, stride-1, padding-1 convolution without bias
    (ii) leaves out the max-pooling layer
    The remaining children are chained in an ``nn.Sequential``, with a
    ``Flatten`` inserted in front of the final linear layer.
    """
    def __init__(self, feature_dim=128, arch=None, bn_splits=16):
        super().__init__()

        # `bn_splits` is accepted but unused: split batchnorm is disabled and
        # plain BatchNorm2d is always passed as the norm layer.
        build_resnet = getattr(resnet, arch)
        template = build_resnet(num_classes=feature_dim, norm_layer=nn.BatchNorm2d)

        layers = []
        for child_name, child in template.named_children():
            if child_name == 'conv1':
                child = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if isinstance(child, nn.MaxPool2d):
                continue
            if isinstance(child, nn.Linear):
                # the Sequential has no implicit flatten before the linear layer
                layers.append(nn.Flatten(1))
            layers.append(child)

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # the output is returned as-is; no normalization is applied here
        return self.net(x)

## Define Contrastive Loss

Implement the contrastive loss function used by SimCLR.


In [5]:
class ModelMoCo(nn.Module):
    """Contrastive model with a query encoder, a momentum key encoder and a key queue.

    Args:
        dim: output feature dimension of both encoders.
        K: number of keys held in the queue (used as negatives).
        m: momentum coefficient for the key-encoder update.
        T: softmax temperature applied to the logits.
        arch: name of the ResNet constructor handed to ``ModelBase``.
        bn_splits: forwarded to ``ModelBase``.
        symmetric: if True, the loss is computed in both directions
            (im1 -> im2 and im2 -> im1) and summed.
    """

    def __init__(self, dim=128, K=4096, m=0.99, T=0.1, arch='resnet18', bn_splits=8, symmetric=True):
        super(ModelMoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T
        self.symmetric = symmetric

        # create the encoders
        self.encoder_q = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
        self.encoder_k = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # start from the query encoder's weights
            param_k.requires_grad = False  # only updated through the momentum rule

        # create the queue: random columns, each normalized to unit length
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        # column index at which the next batch of keys is written
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        param_k <- m * param_k + (1 - m) * param_q
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the queue columns at the current pointer with ``keys`` (N x C)."""
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        # for simplicity the write never wraps around the end of the queue
        assert self.K % batch_size == 0, "queue size K must be a multiple of the key batch size"

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.t()  # transpose to C x N
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_single_gpu(self, x):
        """
        Batch shuffle, for making use of BatchNorm.

        Returns the shuffled batch and the index tensor that restores the
        original order.
        """
        # The permutation is placed on the input's device instead of a
        # hard-coded CUDA device, so the model also runs on CPU.
        idx_shuffle = torch.randperm(x.shape[0]).to(x.device)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        return x[idx_shuffle], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        """
        return x[idx_unshuffle]

    def contrastive_loss(self, im_q, im_k):
        """Compute the contrastive loss for one (query, key) direction.

        Returns:
            (loss, q, k, logits, labels); the positive logit sits in column 0
            of ``logits``, so ``labels`` is all zeros.
        """
        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)  # L2-normalize each query

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # shuffle for making use of BN
            im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k)

            k = self.encoder_k(im_k_)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)  # L2-normalize each key

            # undo shuffle
            k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)

        # compute logits
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: the positive key is always at index 0; created on the same
        # device as the logits rather than on a hard-coded CUDA device
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        loss = nn.functional.cross_entropy(logits, labels)

        return loss, q, k, logits, labels

    def forward(self, im1, im2):
        """
        Input:
            im1: a batch of images (first view)
            im2: a batch of images (second view)
        Output:
            loss, logits, labels
        """

        # update the key encoder
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()

        # compute loss
        if self.symmetric:  # symmetric loss: both directions, summed
            loss_12, _, k2, logits2, labels2 = self.contrastive_loss(im1, im2)
            loss_21, _, k1, logits1, labels1 = self.contrastive_loss(im2, im1)
            loss = loss_12 + loss_21
            k = torch.cat([k1, k2], dim=0)
            logits = torch.cat([logits1, logits2], dim=0)
            labels = torch.cat([labels1, labels2], dim=0)
        else:  # asymmetric loss
            loss, _, k, logits, labels = self.contrastive_loss(im1, im2)

        self._dequeue_and_enqueue(k)

        return loss, logits, labels

## Training SimCLR

Train the SimCLR model using the contrastive loss and augmented image pairs from CIFAR-10.


In [None]:
from torch.utils.tensorboard import SummaryWriter
import os
from tqdm import tqdm
import logging
from utils import accuracy, save_checkpoint
# test using a knn monitor
def test(net, memory_data_loader, test_data_loader, epoch, total_epochs=200, knn_k=200, knn_t=0.1):
    """Monitor representation quality with a weighted kNN classifier.

    Features of the memory set form a bank; each test sample is classified by
    ``knn_predict`` against that bank.

    Args:
        net: the encoder to evaluate.
        memory_data_loader: loader whose dataset exposes ``classes`` and ``targets``.
        test_data_loader: loader yielding ``(data, target)`` batches.
        epoch: current epoch, shown in the progress bar only.
        total_epochs: total epoch count, shown in the progress bar only.
        knn_k: number of neighbours passed to ``knn_predict``.
        knn_t: temperature passed to ``knn_predict``.

    Returns:
        Top-1 accuracy in percent.
    """
    # Remember the mode so it can be restored: leaving the encoder in eval
    # mode would silently freeze BatchNorm for all subsequent training steps.
    was_training = net.training
    net.eval()

    # Run on whatever device the network lives on (CPU fallback if it has no parameters).
    try:
        device = next(net.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    classes = len(memory_data_loader.dataset.classes)
    total_top1, total_num, feature_bank = 0.0, 0, []
    try:
        with torch.no_grad():
            # generate feature bank
            for data, target in tqdm(memory_data_loader, desc='Feature extracting'):
                feature = net(data.to(device, non_blocking=True))
                feature = F.normalize(feature, dim=1)
                feature_bank.append(feature)
            # [D, N]
            feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
            # [N]
            feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
            # loop test data to predict the label by weighted knn search
            test_bar = tqdm(test_data_loader)
            for data, target in test_bar:
                data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
                feature = net(data)
                feature = F.normalize(feature, dim=1)

                pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t)

                total_num += data.size(0)
                total_top1 += (pred_labels[:, 0] == target).float().sum().item()
                test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(epoch, total_epochs, total_top1 / total_num * 100))
    finally:
        net.train(was_training)

    return total_top1 / total_num * 100

# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """Rank classes for each query by similarity-weighted kNN voting.

    Args:
        feature: query features, [B, D].
        feature_bank: bank features, [D, N].
        feature_labels: integer class label of every bank entry, [N].
        classes: number of classes.
        knn_k: number of neighbours to vote; clamped to the bank size N.
        knn_t: temperature; each vote is weighted by exp(similarity / knn_t).

    Returns:
        [B, classes] tensor of class indices sorted by descending score, so
        column 0 is the predicted label.
    """
    batch = feature.size(0)
    # topk raises when k exceeds the number of bank entries, so cap it
    knn_k = min(knn_k, feature_bank.size(-1))

    # similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # labels of the selected neighbours ---> [B, K]
    sim_labels = feature_labels[sim_indices]
    sim_weight = (sim_weight / knn_t).exp()

    # one-hot encode every neighbour label ---> [B*K, C]
    one_hot_label = torch.zeros(batch * knn_k, classes, device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.reshape(-1, 1), value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(batch, -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

    return pred_scores.argsort(dim=-1, descending=True)

with Timer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training with gpu: {device}.")
    # Set number of training epochs
    epochs = 200
    log_every_n_epochs = 10
    # number of initial epochs during which the scheduler is not stepped
    warmup_epochs = 10
    # Initialize model, optimizer and loss criterion
    model = ModelMoCo(
            dim=128,
            K=4096,
            m=0.99,
            T=0.1,
            arch='resnet18',
            bn_splits=1,
            symmetric=True,)
    model = model.to(device)
    lr = 6e-3
    weight_decay = 5e-4
    optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
    # The scheduler is stepped once per epoch (after warmup), so T_max has to
    # count epochs. With T_max=len(train_loader) the cosine reached its
    # minimum part-way through training and then climbed back up.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, epochs - warmup_epochs),
                                                           eta_min=0, last_epoch=-1)
    writer = SummaryWriter()
    logging.basicConfig(filename=os.path.join(writer.log_dir, 'training.log'), level=logging.DEBUG)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    logging.info(f"Start SimCLR training for {epochs} epochs.")
    logging.info(f"Training with gpu: {device}.")
    best_acc = 0
    for epoch_counter in range(epochs):
        # The kNN monitor evaluates the query encoder; make sure every epoch
        # starts in training mode regardless of what evaluation left behind.
        model.train()
        loss_epoch = 0
        for images, _ in tqdm(train_loader):
            im_1, im_2 = images
            im_1, im_2 = im_1.to(device, non_blocking=True), im_2.to(device, non_blocking=True)

            loss, logits, labels = model(im_1, im_2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_epoch += loss.item()
        avg_loss = loss_epoch / len(train_loader)
        # Contrastive accuracy on the last batch, computed every epoch so the
        # per-epoch debug log below never reports a value from an older epoch.
        top1, top5 = accuracy(logits, labels, topk=(1, 5))
        # every log_every_n_epochs log epoch loss and accuracy
        if epoch_counter % log_every_n_epochs == 0:
            test_acc_1_knn = test(model.encoder_q, memory_loader, test_loader, epoch_counter)
            writer.add_scalar('loss', avg_loss, global_step=epoch_counter)
            writer.add_scalar('acc/top1', top1[0], global_step=epoch_counter)
            writer.add_scalar('acc/top5', top5[0], global_step=epoch_counter)
            writer.add_scalar('learning_rate', scheduler.get_last_lr()[0], global_step=epoch_counter)
            writer.add_scalar('test_acc_1_knn', test_acc_1_knn, global_step=epoch_counter)
            if top1[0] > best_acc:
                best_acc = top1[0]
                save_checkpoint({
                    'epoch': epoch_counter,
                    'arch': 'resnet18',
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, is_best=True, filename=os.path.join(writer.log_dir, 'checkpoint_best.pth.tar'))

        # keep the base learning rate for the first warmup_epochs epochs
        if epoch_counter >= warmup_epochs:
            scheduler.step()
        # report the epoch average rather than the last batch's loss tensor
        logging.debug(f"Epoch: {epoch_counter}\tLoss: {avg_loss}\tTop1 accuracy: {top1[0]}")

    logging.info("Training has finished.")
    # save model checkpoints
    checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(epochs)
    save_checkpoint({
        'epoch': epochs,
        'arch': 'resnet18',
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, is_best=False, filename=os.path.join(writer.log_dir, checkpoint_name))
    logging.info(f"Model checkpoint and metadata has been saved at {writer.log_dir}.")
    # flush pending TensorBoard events to disk
    writer.close()

Training with gpu: cuda.


100%|██████████| 97/97 [00:52<00:00,  1.84it/s]
Feature extracting: 100%|██████████| 98/98 [00:05<00:00, 17.34it/s]
Test Epoch: [0/200] Acc@1:21.98%: 100%|██████████| 20/20 [00:01<00:00, 11.59it/s]
100%|██████████| 97/97 [00:50<00:00,  1.92it/s]
 68%|██████▊   | 66/97 [00:34<00:15,  1.95it/s]

Load the model checkpoint and evaluate the learned representations using Linear Probing

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(pretrained=False, num_classes=10).to(device)
# Load the checkpoint
checkpoint_path = 'runs/Sep26_17-15-26_cpsadmin-Z790-AORUS-ELITE-AX/checkpoint_best.pth.tar'
# map_location lets a checkpoint that was saved from a GPU be opened on a
# CPU-only machine as well.
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint['state_dict']

# Keep only the backbone weights (without its fc head) and strip the prefix so
# the keys line up with the plain torchvision ResNet.
# NOTE(review): the training cell in this notebook saves ModelMoCo.state_dict(),
# whose entries live under encoder_q / encoder_k / queue, not 'backbone.'. A
# checkpoint written by that cell would leave nothing to load here and trip the
# assert below -- confirm which code produced this checkpoint.
backbone_prefix = 'backbone.'
state_dict = {
    key[len(backbone_prefix):]: value
    for key, value in state_dict.items()
    if key.startswith(backbone_prefix) and not key.startswith('backbone.fc')
}
log = model.load_state_dict(state_dict, strict=False)
# only the freshly initialized classification head may be missing
assert log.missing_keys == ['fc.weight', 'fc.bias']

In [None]:
# Linear probing: every parameter except the final fc layer stops receiving gradients.
for name, param in model.named_parameters():
    if name in ['fc.weight', 'fc.bias']:
        continue
    param.requires_grad = False

# Collect what is still trainable and check that it is exactly the fc layer.
parameters = [p for p in model.parameters() if p.requires_grad]
assert len(parameters) == 2  # fc.weight, fc.bias

In [None]:
from torchvision import datasets
def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
    """Return ``(train_loader, test_loader)`` for CIFAR-10 stored under ``./data``.

    Images are only converted with ``ToTensor``. The test loader uses twice
    the training batch size; ``shuffle`` is applied to both loaders.
    """
    train_dataset = datasets.CIFAR10(
        './data', train=True, download=download, transform=transforms.ToTensor()
    )
    test_dataset = datasets.CIFAR10(
        './data', train=False, download=download, transform=transforms.ToTensor()
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=8,
        drop_last=False,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=2 * batch_size,
        shuffle=shuffle,
        num_workers=8,
        drop_last=False,
    )
    return train_loader, test_loader


In [None]:
# Data, loss and optimizer for the linear-probing run.
train_loader, test_loader = get_cifar10_data_loaders(download=True)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.0008,
)

In [None]:
from utils import accuracy
epochs = 10
with Timer():
    for epoch in range(epochs):
        # Evaluation below switches to eval mode, so training mode is
        # re-enabled at the start of every epoch.
        model.train()
        top1_train_accuracy = 0
        for counter, (x_batch, y_batch) in enumerate(train_loader):
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            logits = model(x_batch)
            loss = criterion(logits, y_batch)
            top1 = accuracy(logits, y_batch, topk=(1,))
            top1_train_accuracy += top1[0]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # mean of the per-batch accuracies
        top1_train_accuracy /= (counter + 1)

        top1_accuracy = 0
        top5_accuracy = 0
        # Evaluate in eval mode and without autograd: otherwise BatchNorm
        # normalizes with (and updates its running statistics from) test
        # batches, and a graph is built for every test batch for nothing.
        model.eval()
        with torch.no_grad():
            for counter, (x_batch, y_batch) in enumerate(test_loader):
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)

                logits = model(x_batch)

                top1, top5 = accuracy(logits, y_batch, topk=(1,5))
                top1_accuracy += top1[0]
                top5_accuracy += top5[0]

        top1_accuracy /= (counter + 1)
        top5_accuracy /= (counter + 1)
        print(f"Epoch {epoch}:\tTrain Accuracy: {top1_train_accuracy.item():.2f}\tTest Accuracy: {top1_accuracy.item():.2f}\tTest Top-5 Accuracy: {top5_accuracy.item():.2f}")
  
  

### Train a ResNet18 model from scratch on CIFAR-10 using the same augmentation strategy as SimCLR

In [None]:
from torchvision import datasets
def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
    """Return ``(train_loader, test_loader)`` for CIFAR-10 stored under ``./data``.

    Images are only converted with ``ToTensor``. The training loader loads in
    the main process (``num_workers=0``); the test loader uses 10 workers and
    twice the training batch size. ``shuffle`` is applied to both loaders.
    """
    train_dataset = datasets.CIFAR10(
        './data', train=True, download=download, transform=transforms.ToTensor()
    )
    test_dataset = datasets.CIFAR10(
        './data', train=False, download=download, transform=transforms.ToTensor()
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        drop_last=False,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=2 * batch_size,
        shuffle=shuffle,
        num_workers=10,
        drop_last=False,
    )
    return train_loader, test_loader


In [None]:
from torchvision.models import resnet18
# Randomly initialized ResNet18 with a 10-way head, plus its loss, optimizer and data.
model = resnet18(pretrained=False, num_classes=10)
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.0008,
)
train_loader, test_loader = get_cifar10_data_loaders(download=True)

In [None]:
from utils import accuracy
epochs = 10
with Timer():
    for epoch in range(epochs):
        # Evaluation below switches to eval mode, so training mode is
        # re-enabled at the start of every epoch.
        model.train()
        top1_train_accuracy_sup = 0
        for counter, (x_batch, y_batch) in enumerate(train_loader):
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            logits = model(x_batch)
            loss = criterion(logits, y_batch)
            top1 = accuracy(logits, y_batch, topk=(1,))
            top1_train_accuracy_sup += top1[0]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # mean of the per-batch accuracies
        top1_train_accuracy_sup /= (counter + 1)

        top1_accuracy_sup = 0
        top5_accuracy_sup = 0
        # Evaluate in eval mode and without autograd: otherwise BatchNorm
        # normalizes with (and updates its running statistics from) test
        # batches, and a graph is built for every test batch for nothing.
        model.eval()
        with torch.no_grad():
            for counter, (x_batch, y_batch) in enumerate(test_loader):
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)

                logits = model(x_batch)

                top1, top5 = accuracy(logits, y_batch, topk=(1,5))
                top1_accuracy_sup += top1[0]
                top5_accuracy_sup += top5[0]

        top1_accuracy_sup /= (counter + 1)
        top5_accuracy_sup /= (counter + 1)
        print(f"Epoch {epoch}:\tTrain Accuracy: {top1_train_accuracy_sup.item():.2f}\tTest Accuracy: {top1_accuracy_sup.item():.2f}\tTest Top-5 Accuracy: {top5_accuracy_sup.item():.2f}")
            

In [None]:
# Pretrained ResNet18 whose final fc layer is replaced by a fresh 512 -> 10 classifier.
model = resnet18(pretrained=True)
model.fc = nn.Linear(512, 10)
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.0008,
)
train_loader, test_loader = get_cifar10_data_loaders(download=True)

In [None]:
from utils import accuracy
epochs = 10
with Timer():
    for epoch in range(epochs):
        # Evaluation below switches to eval mode, so training mode is
        # re-enabled at the start of every epoch.
        model.train()
        top1_train_accuracy_sup_pre = 0
        for counter, (x_batch, y_batch) in enumerate(train_loader):
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            logits = model(x_batch)
            loss = criterion(logits, y_batch)
            top1 = accuracy(logits, y_batch, topk=(1,))
            top1_train_accuracy_sup_pre += top1[0]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # mean of the per-batch accuracies
        top1_train_accuracy_sup_pre /= (counter + 1)

        top1_accuracy_sup_pre = 0
        top5_accuracy_sup_pre = 0
        # Evaluate in eval mode and without autograd: otherwise BatchNorm
        # normalizes with (and updates its running statistics from) test
        # batches, and a graph is built for every test batch for nothing.
        model.eval()
        with torch.no_grad():
            for counter, (x_batch, y_batch) in enumerate(test_loader):
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)

                logits = model(x_batch)

                top1, top5 = accuracy(logits, y_batch, topk=(1,5))
                top1_accuracy_sup_pre += top1[0]
                top5_accuracy_sup_pre += top5[0]

        top1_accuracy_sup_pre /= (counter + 1)
        top5_accuracy_sup_pre /= (counter + 1)
        # results are printed after every epoch
        print(f"Epoch {epoch}:\tTrain Accuracy: {top1_train_accuracy_sup_pre.item():.2f}\tTest Accuracy: {top1_accuracy_sup_pre.item():.2f}\tTest Top-5 Accuracy: {top5_accuracy_sup_pre.item():.2f}")
            

In [None]:
# Side-by-side summary of the three runs (values from each run's final epoch).
print(f"{'Model':<25}{'Train Accuracy':<20}{'Test Accuracy':<20}{'Test Top-5 Accuracy':<20}")
summary_rows = (
    ('SimCLR', top1_train_accuracy, top1_accuracy, top5_accuracy),
    ('Supervised', top1_train_accuracy_sup, top1_accuracy_sup, top5_accuracy_sup),
    ('Supervised Pretrained', top1_train_accuracy_sup_pre, top1_accuracy_sup_pre, top5_accuracy_sup_pre),
)
for row_label, train_acc, test_acc, test_top5 in summary_rows:
    print(f"{row_label:<25}{train_acc.item():<20.2f}{test_acc.item():<20.2f}{test_top5.item():<20.2f}")