# SimCLR Implementation and Evaluation on CIFAR-10

This notebook implements the SimCLR algorithm, trains it on the CIFAR-10 dataset, and evaluates the learned representations using Linear Probing and K-Nearest Neighbors (KNN) classification.


In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import torch.nn.functional as F
import torchvision.models as models

# Importing necessary libraries and modules for the implementation.

### Execution Timers

In [2]:

# Flag to enable or disable timers
enable_timers = True

import time

class Timer:
    """Context manager that prints the elapsed time of its ``with`` body.

    Usage::

        with Timer():
            ...  # prints "Elapsed time: X.XX seconds" on exit

    Args:
        enabled: whether to time and print. ``None`` (default) uses the
            module-level ``enable_timers`` flag as it is when ``__enter__`` runs.

    When enabled, ``start``/``end`` hold ``time.perf_counter()`` readings and
    ``interval`` holds the elapsed seconds after the block exits. Exceptions
    raised inside the block are not suppressed (``__exit__`` returns None).
    """

    def __init__(self, enabled=None):
        self.enabled = enabled

    def __enter__(self):
        # Resolve the flag once, so flipping ``enable_timers`` while the block
        # runs cannot leave __exit__ without a start time.
        self._active = enable_timers if self.enabled is None else self.enabled
        if self._active:
            # perf_counter is monotonic and high-resolution; time.time() can
            # jump when the system clock is adjusted.
            self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        if self._active:
            self.end = time.perf_counter()
            self.interval = self.end - self.start
            print(f"Elapsed time: {self.interval:.2f} seconds")
    

# `Timer` prints the elapsed wall-clock time of any block wrapped in `with Timer():`; set `enable_timers = False` to silence it.

## Load CIFAR-10 Dataset

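Load the CIFAR-10 training set for contrastive pre-training. Each sample is returned as two augmented views of the same image (the positive pair that SimCLR contrasts). Only the training split is loaded here.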


In [3]:
from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset

# The second argument of get_dataset is presumably the number of augmented views
# per image (2, matching info_nce_loss); the training loop concatenates the views
# along the batch axis with torch.cat.
dataset = ContrastiveLearningDataset(root_folder='data')
train_dataset = dataset.get_dataset('cifar10', 2)

train_loader = DataLoader(
    train_dataset,
    batch_size=512,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    # Drop the ragged final batch so every step sees exactly 512 images per view.
    drop_last=True,
)


Files already downloaded and verified


## Define SimCLR Encoder and Projection Head

Build the SimCLR model: a torchvision ResNet backbone (ResNet-18 here; ResNet-50 is also supported) whose final fully connected layer is replaced by a two-layer MLP projection head.


In [4]:
class ResNetSimCLR(nn.Module):
    """SimCLR model: a torchvision ResNet whose ``fc`` layer becomes an MLP projection head.

    The backbone's original ``fc`` layer (``in_features -> out_dim``) is kept as the
    last layer and preceded by ``Linear(in_features, in_features) -> ReLU``, so
    ``forward`` returns ``out_dim``-dimensional projections.

    Args:
        base_model: backbone name, one of ``SUPPORTED_BACKBONES``.
        out_dim: output size of the projection head.

    Raises:
        ValueError: if ``base_model`` is not a supported backbone name.
    """

    # Names accepted for ``base_model``; each is looked up as a constructor on torchvision.models.
    SUPPORTED_BACKBONES = ("resnet18", "resnet50")

    def __init__(self, base_model, out_dim):
        super().__init__()
        # Build only the requested backbone. Previously both ResNet-18 and ResNet-50
        # were instantiated on every construction, and the unused one was kept alive.
        self.backbone = self._get_basemodel(base_model, out_dim)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head: hidden Linear + ReLU in front of the original fc layer
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

    def _get_basemodel(self, model_name, out_dim):
        """Return a randomly initialised torchvision ``model_name`` with ``num_classes=out_dim``."""
        if model_name not in self.SUPPORTED_BACKBONES:
            # Raise a real exception type: the former ``raise ("...")`` raised a str,
            # which itself fails with an unrelated TypeError.
            raise ValueError(
                f"Invalid model name {model_name!r}. "
                f"Check the config file and pass one of: {', '.join(self.SUPPORTED_BACKBONES)}"
            )
        # No pretrained/weights kwarg: both default to random initialisation, and
        # ``pretrained`` is deprecated in recent torchvision releases.
        return getattr(models, model_name)(num_classes=out_dim)

    def forward(self, x):
        """Run the backbone plus projection head; returns one ``out_dim`` vector per input."""
        return self.backbone(x)

## Define Contrastive Loss

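Implement SimCLR's contrastive (NT-Xent / InfoNCE) objective. `info_nce_loss` returns logits and target indices; the loss itself is obtained by passing them to `nn.CrossEntropyLoss`.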


In [5]:
def info_nce_loss(features, temperature=0.5):
    """Build NT-Xent (InfoNCE) logits and targets for a batch of paired views.

    ``features`` stacks two views of the same ``batch_size`` images along dim 0:
    rows ``i`` and ``i + batch_size`` are a positive pair. Feeding the returned
    ``(logits, labels)`` to ``nn.CrossEntropyLoss`` yields the SimCLR loss.

    Args:
        features: 2-D tensor of shape ``(2 * batch_size, dim)``.
        temperature: logits are cosine similarities divided by this value.

    Returns:
        logits: ``(2 * batch_size, 2 * batch_size - 1)`` tensor; column 0 holds
            the positive pair's similarity, the remaining columns the negatives.
        labels: ``(2 * batch_size,)`` long tensor of zeros (index of the positive).
        Both are on ``features.device``.

    Raises:
        ValueError: if ``features`` is not 2-D with a non-zero, even number of rows.
    """
    n_views = 2  # the positive/negative split below assumes exactly one positive per row
    n = features.shape[0] if features.dim() == 2 else 0
    if features.dim() != 2 or n == 0 or n % n_views != 0:
        raise ValueError(
            f"features must be 2-D with a non-zero, even number of rows; got shape {tuple(features.shape)}"
        )
    batch_size = n // n_views
    # Follow the input's device instead of assuming CUDA whenever it is available.
    device = features.device

    # same_image[i, j] is True when rows i and j are views of the same image.
    image_idx = torch.arange(batch_size, device=device).repeat(n_views)
    same_image = image_idx.unsqueeze(0) == image_idx.unsqueeze(1)

    features = F.normalize(features, dim=1)
    similarity_matrix = features @ features.T

    # Discard the main diagonal (self-similarity) from both matrices.
    off_diag = ~torch.eye(n, dtype=torch.bool, device=device)
    same_image = same_image[off_diag].view(n, n - 1)
    similarity_matrix = similarity_matrix[off_diag].view(n, n - 1)

    # Explicit sizes (instead of -1) keep batch_size == 1 working (zero negatives).
    positives = similarity_matrix[same_image].view(n, 1)
    negatives = similarity_matrix[~same_image].view(n, n - 2)

    logits = torch.cat([positives, negatives], dim=1) / temperature
    # The positive is always column 0.
    labels = torch.zeros(n, dtype=torch.long, device=device)
    return logits, labels

## Training SimCLR

Train the SimCLR model using the contrastive loss and augmented image pairs from CIFAR-10.


In [None]:
from torch.utils.tensorboard import SummaryWriter
import os
from tqdm import tqdm
import logging
from utils import accuracy, save_checkpoint

# Create a directory to save the model weights
save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)

# Append a start marker to the plain-text run log
log_file = "training_log.txt"
with open(log_file, 'a') as f:  # 'a' means append mode
    f.write("Training started...\n")

with Timer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training with gpu: {device}.")

    # Training schedule
    epochs = 200
    warmup_epochs = 10  # LR is held constant (no cosine decay) for these first epochs
    log_every_n_epochs = 1

    # Model, optimizer, LR schedule and loss criterion
    model = ResNetSimCLR(base_model='resnet18', out_dim=128)
    model = model.to(device)
    lr = 3e-4
    weight_decay = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
    # scheduler.step() runs once per epoch after the constant-LR phase, so T_max is
    # measured in epochs: the LR reaches eta_min exactly at the last epoch. With the
    # former T_max=len(train_loader) (a batch count), the LR bottomed out after
    # ~len(train_loader) epochs and then climbed back up.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=0, last_epoch=-1)
    writer = SummaryWriter()
    logging.basicConfig(filename=os.path.join(writer.log_dir, 'training.log'), level=logging.DEBUG)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    logging.info(f"Start SimCLR training for {epochs} epochs.")
    logging.info(f"Training with gpu: {device}.")
    best_acc = 0
    for epoch_counter in range(epochs):
        loss_epoch = 0
        for images, _ in tqdm(train_loader):
            # images is a sequence of per-view batches; stack the views along the
            # batch axis (the layout info_nce_loss expects).
            images = torch.cat(images, dim=0)
            images = images.to(device)

            features = model(images)
            logits, labels = info_nce_loss(features)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_epoch += loss.item()
        avg_loss = loss_epoch / len(train_loader)
        print(f"Epoch {epoch_counter}:\tLoss: {avg_loss}")

        # every log_every_n_epochs log epoch loss and accuracy
        if epoch_counter % log_every_n_epochs == 0:
            # Contrastive top-k accuracy is measured on the last batch of the epoch only.
            top1, top5 = accuracy(logits, labels, topk=(1, 5))
            writer.add_scalar('loss', avg_loss, global_step=epoch_counter)
            writer.add_scalar('acc/top1', top1[0], global_step=epoch_counter)
            writer.add_scalar('acc/top5', top5[0], global_step=epoch_counter)
            writer.add_scalar('learning_rate', scheduler.get_last_lr()[0], global_step=epoch_counter)
            if top1[0] > best_acc:
                best_acc = top1[0]
                save_checkpoint({
                    'epoch': epoch_counter,
                    'arch': 'resnet18',
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, is_best=True, filename=os.path.join(writer.log_dir, 'checkpoint_best.pth.tar'))
            # Logged here because top1 only exists on logging epochs; uses the epoch
            # average loss rather than the last batch's loss tensor.
            logging.debug(f"Epoch: {epoch_counter}\tLoss: {avg_loss}\tTop1 accuracy: {top1[0]}")

        # Hold the LR constant for the first warmup_epochs, then follow the cosine decay.
        if epoch_counter >= warmup_epochs:
            scheduler.step()

    logging.info("Training has finished.")
    # save model checkpoints
    checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(epochs)
    save_checkpoint({
        'epoch': epochs,
        'arch': 'resnet18',
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, is_best=False, filename=os.path.join(writer.log_dir, checkpoint_name))
    logging.info(f"Model checkpoint and metadata has been saved at {writer.log_dir}.")
    # Flush pending TensorBoard events and release the event file.
    writer.close()

Training with gpu: cuda.


100%|██████████| 97/97 [00:12<00:00,  7.80it/s]


Epoch 0:	Loss: 6.509977930599881


100%|██████████| 97/97 [00:12<00:00,  7.73it/s]


Epoch 1:	Loss: 6.301587645540533


100%|██████████| 97/97 [00:13<00:00,  7.33it/s]


Epoch 2:	Loss: 6.224259858278884


100%|██████████| 97/97 [00:14<00:00,  6.58it/s]


Epoch 3:	Loss: 6.174492098621486


100%|██████████| 97/97 [00:13<00:00,  7.06it/s]


Epoch 4:	Loss: 6.143865600074689


100%|██████████| 97/97 [00:13<00:00,  7.08it/s]


Epoch 5:	Loss: 6.105672202159449


100%|██████████| 97/97 [00:13<00:00,  7.13it/s]


Epoch 6:	Loss: 6.091826984562824


100%|██████████| 97/97 [00:13<00:00,  6.95it/s]


Epoch 7:	Loss: 6.06757251995126


100%|██████████| 97/97 [00:13<00:00,  6.95it/s]


Epoch 8:	Loss: 6.061101736481657


100%|██████████| 97/97 [00:13<00:00,  6.97it/s]


Epoch 9:	Loss: 6.044455862536873


100%|██████████| 97/97 [00:13<00:00,  7.20it/s]


Epoch 10:	Loss: 6.037745524927513


100%|██████████| 97/97 [00:13<00:00,  7.40it/s]


Epoch 11:	Loss: 6.016408045267322


100%|██████████| 97/97 [00:12<00:00,  7.71it/s]


Epoch 12:	Loss: 6.009699826388015


100%|██████████| 97/97 [00:12<00:00,  7.72it/s]


Epoch 13:	Loss: 6.001906611255763


100%|██████████| 97/97 [00:12<00:00,  7.56it/s]


Epoch 14:	Loss: 5.995030737414803


100%|██████████| 97/97 [00:12<00:00,  7.56it/s]


Epoch 15:	Loss: 5.988785011252177


100%|██████████| 97/97 [00:12<00:00,  7.56it/s]


Epoch 16:	Loss: 5.983934127178388


100%|██████████| 97/97 [00:12<00:00,  7.63it/s]


Epoch 17:	Loss: 5.980731211986738


100%|██████████| 97/97 [00:12<00:00,  7.63it/s]


Epoch 18:	Loss: 5.968473051012177


100%|██████████| 97/97 [00:12<00:00,  7.76it/s]


Epoch 19:	Loss: 5.956287059587302


100%|██████████| 97/97 [00:13<00:00,  7.09it/s]


Epoch 20:	Loss: 5.954351872512975


100%|██████████| 97/97 [00:12<00:00,  7.57it/s]


Epoch 21:	Loss: 5.952128793775421


100%|██████████| 97/97 [00:12<00:00,  7.72it/s]


Epoch 22:	Loss: 5.9355814039092705


100%|██████████| 97/97 [00:12<00:00,  7.60it/s]


Epoch 23:	Loss: 5.9270596799162245


100%|██████████| 97/97 [00:12<00:00,  7.74it/s]


Epoch 24:	Loss: 5.922943592071533


100%|██████████| 97/97 [00:13<00:00,  7.39it/s]


Epoch 25:	Loss: 5.923323405157659


100%|██████████| 97/97 [00:13<00:00,  6.94it/s]


Epoch 26:	Loss: 5.915223298613558


100%|██████████| 97/97 [00:13<00:00,  7.14it/s]


Epoch 27:	Loss: 5.906255436926773


100%|██████████| 97/97 [00:13<00:00,  7.11it/s]


Epoch 28:	Loss: 5.90319547948149


100%|██████████| 97/97 [00:14<00:00,  6.85it/s]


Epoch 29:	Loss: 5.894468966218614


100%|██████████| 97/97 [00:12<00:00,  7.47it/s]


Epoch 30:	Loss: 5.891392462032357


100%|██████████| 97/97 [00:13<00:00,  7.03it/s]


Epoch 31:	Loss: 5.883002807184593


100%|██████████| 97/97 [00:14<00:00,  6.60it/s]


Epoch 32:	Loss: 5.884909639653471


100%|██████████| 97/97 [00:13<00:00,  6.96it/s]


Epoch 33:	Loss: 5.87393271062792


100%|██████████| 97/97 [00:13<00:00,  7.18it/s]


Epoch 34:	Loss: 5.873671467771235


100%|██████████| 97/97 [00:13<00:00,  7.40it/s]


Epoch 35:	Loss: 5.876488434899714


100%|██████████| 97/97 [00:13<00:00,  7.33it/s]


Epoch 36:	Loss: 5.868537160539136


100%|██████████| 97/97 [00:12<00:00,  7.47it/s]


Epoch 37:	Loss: 5.8613553194655585


100%|██████████| 97/97 [00:12<00:00,  7.76it/s]


Epoch 38:	Loss: 5.857223132221969


100%|██████████| 97/97 [00:13<00:00,  7.42it/s]


Epoch 39:	Loss: 5.853782663640287


100%|██████████| 97/97 [00:13<00:00,  7.41it/s]


Epoch 40:	Loss: 5.853663218390081


100%|██████████| 97/97 [00:13<00:00,  7.32it/s]


Epoch 41:	Loss: 5.8502272674717855


100%|██████████| 97/97 [00:12<00:00,  7.78it/s]


Epoch 42:	Loss: 5.840807668941537


100%|██████████| 97/97 [00:12<00:00,  7.49it/s]


Epoch 43:	Loss: 5.839038735812473


100%|██████████| 97/97 [00:12<00:00,  7.46it/s]


Epoch 44:	Loss: 5.840348587822668


100%|██████████| 97/97 [00:13<00:00,  7.25it/s]


Epoch 45:	Loss: 5.830540637380069


100%|██████████| 97/97 [00:13<00:00,  7.44it/s]


Epoch 46:	Loss: 5.829489978318362


 72%|███████▏  | 70/97 [00:09<00:02,  9.52it/s]