
SimCLR is a framework for contrastive learning of visual representations. It learns representations by maximizing agreement between differently augmented views of the same data example, using a contrastive loss in the latent space.

These visual representations are vectors on which linear classifiers can be trained to solve tasks such as image classification. Such representations are usually learned by training deep networks like ResNet on labeled datasets like ImageNet; SimCLR instead learns them from unlabeled images.


This notebook contains a PyTorch implementation of the paper [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709) by Chen et al.



### Importing the required packages

In [None]:
import numpy as np
from tqdm.notebook import tqdm
from PIL import Image

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as tfs
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50

import warnings
warnings.filterwarnings("ignore")

# Use the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

### Define the Transformations

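**Data augmentation module:** This module stochastically transforms any given data example into two correlated views of the same example, denoted $x_i$ and $x_j$. The authors sequentially apply three simple augmentations: random cropping followed by resizing back to the original size, random color distortion, and random Gaussian blur. In this notebook, `tf_tr` uses random resized crops, random horizontal flips, and color jitter (no Gaussian blur), while `tf_de` (for training the linear classifier) and `tf_te` (for testing) apply no augmentation and only convert and normalize the images.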

In [None]:
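# Augmentations for contrastive pre-training (applied twice to every image)
tf_tr = tfs.Compose([
    tfs.RandomResizedCrop(32),
    tfs.RandomHorizontalFlip(),
    tfs.ColorJitter(0.5, 0.5, 0.5, 0.5),  # brightness, contrast, saturation, hue
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet channel statistics
                  std=[0.229, 0.224, 0.225])
])

# No augmentation when training the linear classifier
tf_de = tfs.Compose([
    tfs.Resize(32),
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225])
])

# No augmentation at test time
tf_te = tfs.Compose([
    tfs.Resize(32),
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225])
])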
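
The key property of the augmentation module is that the *same* random pipeline is sampled twice per image, so the two views share content but differ in crop, flip, and color. The NumPy sketch below mimics a simplified version of `tf_tr` under stated assumptions (square crops only, nearest-neighbour resizing, brightness-only jitter); the helpers `random_resized_crop`, `augment`, and `two_views` are hypothetical names for illustration, not torchvision APIs.

```python
import numpy as np

def random_resized_crop(img, rng, out_size=32, scale=(0.08, 1.0)):
    """Hypothetical helper: crop a random square covering a random fraction of
    the image area, then resize it to out_size x out_size (nearest neighbour)."""
    h, w, _ = img.shape
    area = rng.uniform(*scale) * h * w
    side = int(np.clip(np.sqrt(area), 1, min(h, w)))
    top = rng.integers(0, h - side + 1)
    left = rng.integers(0, w - side + 1)
    crop = img[top:top + side, left:left + side]
    # Map every output pixel index to a source pixel inside the crop
    idx = (np.arange(out_size) * side / out_size).astype(int)
    return crop[idx][:, idx]

def augment(img, rng):
    """One stochastic view: crop + resize, random horizontal flip, brightness jitter."""
    view = random_resized_crop(img, rng)
    if rng.random() < 0.5:
        view = view[:, ::-1]  # flip along the width axis
    # Brightness factor in [0.5, 1.5], like ColorJitter(brightness=0.5)
    factor = rng.uniform(0.5, 1.5)
    return np.clip(view * factor, 0.0, 1.0)

def two_views(img, rng):
    """Sample the same random pipeline twice -> array of shape [2, H, W, C]."""
    return np.stack([augment(img, rng), augment(img, rng)])

rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))  # stand-in for a CIFAR-10 image scaled to [0, 1]
views = two_views(image, rng)
print(views.shape)  # (2, 32, 32, 3)
```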

### Loading the CIFAR-10 dataset

In [None]:
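class CustomCIFAR10(CIFAR10):
    """CIFAR-10 that returns two augmented views of each training image (no label)."""
    def __init__(self, **kwds):
        super().__init__(**kwds)

    def __getitem__(self, idx):
        # Test split: behave like the standard CIFAR10 dataset and return (image, label)
        if not self.train:
            return super().__getitem__(idx)

        img = self.data[idx]                       # uint8 array of shape (32, 32, 3)
        img = Image.fromarray(img).convert('RGB')
        # Apply the random transform twice to get two correlated views
        imgs = [self.transform(img), self.transform(img)]
        return torch.stack(imgs)                   # shape (2, C, H, W)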
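
Because `__getitem__` stacks the two views along a new dimension, the DataLoader yields batches of shape `[B, 2, C, H, W]`. The NumPy sketch below uses made-up tiny sizes (it is an illustration, not part of the pipeline) to show that flattening such a batch to `[2B, C, H, W]`, as the training loop does with `data.view(...)`, puts the two views of image *k* in rows `2k` and `2k+1`. That is the pairing `nt_xent` relies on when it swaps neighbouring rows.

```python
import numpy as np

B, C, H, W = 4, 3, 2, 2  # tiny hypothetical sizes for illustration
# batch[k, v] is view v (0 or 1) of image k; fill it with k to track where rows go
batch = np.zeros((B, 2, C, H, W))
for k in range(B):
    batch[k] = k

# Same reshape as data.view(d[0]*2, d[2], d[3], d[4]) in the training loop
flat = batch.reshape(B * 2, C, H, W)
owner = flat[:, 0, 0, 0].astype(int)  # which image each row came from
print(owner)  # [0 0 1 1 2 2 3 3]

# Row permutation used in nt_xent: swap rows 2k <-> 2k+1
idx = np.arange(B * 2)
idx[::2] += 1
idx[1::2] -= 1
print(idx)  # [1 0 3 2 5 4 7 6]
# After the swap, every row is matched with the other view of the same image
assert (owner[idx] == owner).all()
```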

In [None]:
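# Download the dataset and apply the transformations
# ds_tr: two augmented views per image, no labels (contrastive pre-training)
ds_tr = CustomCIFAR10(root='data', train=True, transform=tf_tr, download=True)
# ds_de: training images with labels (linear classifier)
ds_de = CIFAR10(root='data', train=True, transform=tf_de, download=True)
# ds_te: test images with labels (evaluation)
ds_te = CIFAR10(root='data', train=False, transform=tf_te, download=True)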

**Batch Size:** The authors experimented with batch sizes N ranging from 256 to 8192. A batch size of 8192 gives 16382 negative examples per positive pair, drawn from both augmentation views. Because training with large batch sizes can be unstable with standard SGD/momentum and linear learning-rate scaling, the authors used the LARS optimizer for all batch sizes. With 128 TPU v3 cores, training a ResNet-50 with a batch size of 4096 for 100 epochs takes about 1.5 hours.
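
The negative count follows from the batch construction: N images yield 2N augmented views, and for any view every other view except its partner is a negative, i.e. $2N - 2 = 2(N - 1)$. A quick check (the function name is hypothetical):

```python
def negatives_per_positive_pair(n):
    """For a batch of n images (2n views): all views except the anchor and its partner."""
    return 2 * (n - 1)

for n in (128, 256, 4096, 8192):
    print(n, negatives_per_positive_pair(n))
# 128 254
# 256 510
# 4096 8190
# 8192 16382
```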

This notebook uses a batch size of 128 images (256 augmented views per step); you can increase it to 256 if GPU memory allows.

In [None]:
# Load the dataset using dataloader
dl_tr = DataLoader(ds_tr, batch_size=128, shuffle=True)
dl_de = DataLoader(ds_de, batch_size=128, shuffle=True)
dl_te = DataLoader(ds_te, batch_size=128, shuffle=False)

### Building the ResNet-50 Model

**Base Encoder:** ResNet-50 is used as the base encoder that extracts representation vectors from the augmented data examples. The output of the final average-pooling layer is used as the representation.

**Projection Head:** A small neural network, an MLP with one hidden layer and a ReLU activation, maps the representations from the base encoder to a 128-dimensional latent space where the contrastive loss is applied.
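
In equation form the head computes $z = g(h) = W_2\,\mathrm{ReLU}(W_1 h)$. The NumPy sketch below shows only the shapes involved, with randomly initialised weights; the 2048-unit hidden width is an assumption chosen to match the `ch` width used in the code cell that builds the head, and the weights are hypothetical, not trained values.

```python
import numpy as np

rng = np.random.default_rng(0)
rep_dim, hidden_dim, proj_dim = 2048, 2048, 128  # pooled ResNet-50 features -> 128-d
batch = 4

# Hypothetical randomly initialised weights for the two linear layers
W1 = rng.normal(0, 0.02, size=(rep_dim, hidden_dim))
b1 = np.zeros(hidden_dim)
W2 = rng.normal(0, 0.02, size=(hidden_dim, proj_dim))
b2 = np.zeros(proj_dim)

def projection_head(h):
    """z = W2 ReLU(W1 h): one hidden layer with a ReLU non-linearity."""
    hidden = np.maximum(h @ W1 + b1, 0.0)  # ReLU
    return hidden @ W2 + b2

h = rng.normal(size=(batch, rep_dim))  # stand-in for encoder outputs after avg-pooling
z = projection_head(h)
print(h.shape, z.shape)  # (4, 2048) (4, 128)
```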

In [None]:
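# ResNet-50 trained from scratch (no pretrained weights)
model = resnet50(weights=None)
# CIFAR-10 images are only 32x32: replace the 7x7 stride-2 stem convolution
# with a 3x3 stride-1 convolution and remove the first max-pooling layer
model.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
model.maxpool = nn.Identity()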
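
Why change the stem? The standard convolution output-size formula, $\lfloor (s + 2p - k)/r \rfloor + 1$ for input size $s$, kernel $k$, stride $r$, and padding $p$, shows how much resolution survives. The small helpers below are our own illustration; they assume torchvision's ResNet-50 layout (7×7 stride-2 conv and 3×3 stride-2 max-pool in the stem, then stride-2 downsampling in layer2, layer3, and layer4).

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

def final_feature_map(size, cifar_stem):
    """Side length of the feature map that reaches the final average pooling."""
    if cifar_stem:
        size = conv_out(size, 3, 1, 1)  # 3x3 conv, stride 1, no max-pool
    else:
        size = conv_out(size, 7, 2, 3)  # original 7x7 conv, stride 2
        size = conv_out(size, 3, 2, 1)  # original 3x3 max-pool, stride 2
    # layer1 keeps the resolution; layer2-4 each halve it (3x3, stride 2, padding 1)
    for _ in range(3):
        size = conv_out(size, 3, 2, 1)
    return size

print(final_feature_map(32, cifar_stem=False))  # 1
print(final_feature_map(32, cifar_stem=True))   # 4
```

With the original stem a 32×32 image is reduced to a single spatial position before pooling; the modified stem keeps a 4×4 map.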

In [None]:
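# Replace the classification layer with the projection head g(h) = W2 ReLU(W1 h)
ch = model.fc.in_features  # 2048-dimensional representation h
model.fc = nn.Sequential(nn.Linear(ch, ch),
                         nn.ReLU(),
                         nn.Linear(ch, 128))  # 128-d output z for the contrastive loss
model.to(device)
model.train()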

### Define the Contrastive loss Function

**Contrastive Loss Function:** Given a set of examples that includes a positive pair ($x_i$ and $x_j$), the contrastive prediction task aims to identify $x_j$ among the other examples in the set, given $x_i$. Within a batch, the other $2(N-1)$ augmented views act as negatives.
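
Written out, this is the NT-Xent (normalized temperature-scaled cross-entropy) loss from the paper. With projections $z$, cosine similarity $\mathrm{sim}(u, v) = u^\top v / (\lVert u \rVert \lVert v \rVert)$, and temperature $\tau$ (the argument `t` of `nt_xent`), the loss for a positive pair $(i, j)$ in a batch of $2N$ views is

$$\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

and the batch loss averages it over both orderings of every positive pair (1-based indices; in 0-indexed code these are rows $2k$ and $2k+1$):

$$\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[\ell_{2k-1,2k} + \ell_{2k,2k-1}\right]$$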

In [None]:
import math

# Cosine similarity between every pair of rows of x (shape [2N, D]) -> [2N, 2N]
def pair_cosine_similarity(x, eps=1e-8):
    n = x.norm(p=2, dim=1, keepdim=True)
    return (x @ x.t()) / (n * n.t()).clamp(min=eps)

def nt_xent(x, t=0.5):
    # x holds 2N projections; rows 2k and 2k+1 are the two views of image k
    x = pair_cosine_similarity(x)
    x = torch.exp(x / t)
    # Swap rows 2k <-> 2k+1 so that each positive pair lands on the diagonal
    idx = torch.arange(x.size(0), device=x.device)
    idx[::2] += 1
    idx[1::2] -= 1
    x = x[idx]
    # Column sums are unaffected by the row swap; subtracting exp(1/t) removes the
    # self-similarity term sim(z_i, z_i) = 1 from the denominator
    x = x.diag() / (x.sum(0) - math.exp(1 / t))
    # Average the per-view negative log-probabilities of the positive partner
    return -torch.log(x).mean()
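
To cross-check `nt_xent` on the same inputs, the NumPy sketch below computes NT-Xent in a different way: it masks the diagonal instead of subtracting $\exp(1/t)$, and uses a log-sum-exp for numerical stability. `nt_xent_reference` is a hypothetical helper, not part of the notebook; it assumes the same row layout (rows `2k` and `2k+1` are partners).

```python
import numpy as np

def nt_xent_reference(z, tau=0.5):
    """Mean over all 2N views of -log softmax probability of the positive partner."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit rows: dot product = cosine
    logits = (z @ z.T) / tau
    n2 = z.shape[0]
    np.fill_diagonal(logits, -np.inf)                 # exclude k == i from the denominator
    partner = np.arange(n2) ^ 1                       # 0<->1, 2<->3, ...
    m = logits.max(axis=1, keepdims=True)
    log_denom = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))  # log-sum-exp
    losses = log_denom - logits[np.arange(n2), partner]
    return losses.mean()

rng = np.random.default_rng(0)
z_random = rng.normal(size=(8, 16))
# Partners are near-copies of each other, so positives are highly similar
z_aligned = np.repeat(rng.normal(size=(4, 16)), 2, axis=0) + 0.01 * rng.normal(size=(8, 16))
print(nt_xent_reference(z_aligned) < nt_xent_reference(z_random))  # True
```

As expected, well-aligned positive pairs produce a lower loss than unrelated ones.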

In [None]:
# Define the optimizer (Adam here, instead of the LARS optimizer used in the paper)
optimizer = Adam(model.parameters(), lr=0.005)
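
The paper trains with LARS rather than Adam. As a rough illustration of what LARS changes, the NumPy sketch below applies a layer-wise trust ratio, $\eta \lVert w \rVert / (\lVert g \rVert + \beta \lVert w \rVert)$, on top of momentum SGD, so that the step size of each layer is tied to the norm of its weights. It is a simplified single-layer sketch with assumed hyperparameters (`lr`, `momentum`, `weight_decay`, `trust_coef`), not a drop-in optimizer and not the paper's exact configuration.

```python
import numpy as np

def lars_step(w, g, v, lr=0.1, momentum=0.9, weight_decay=1e-6, trust_coef=0.001):
    """One LARS-style update for a single layer's weights w with gradient g.

    v is the momentum buffer; returns the updated (w, v)."""
    w_norm = np.linalg.norm(w)
    g_norm = np.linalg.norm(g)
    # Layer-wise trust ratio: scale the step by the weight-to-gradient norm ratio
    if w_norm > 0 and g_norm > 0:
        local_lr = trust_coef * w_norm / (g_norm + weight_decay * w_norm)
    else:
        local_lr = 1.0
    v = momentum * v + lr * local_lr * (g + weight_decay * w)
    return w - v, v

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 32))
g = 1000.0 * rng.normal(size=(64, 32))  # a very large gradient
v = np.zeros_like(w)
w_new, v = lars_step(w, g, v)
# The relative change stays small even though the gradient is huge
print(np.linalg.norm(w_new - w) / np.linalg.norm(w) < 0.01)  # True
```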

In [None]:
# Free cached GPU memory before training
torch.cuda.empty_cache()

### Train the model


In [None]:
!mkdir -p checkpoint

In [None]:
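import os

model.train()
PATH = '/content/checkpoint'  # directory created in the previous cell
for i in range(20):  # Set the range to 100 to train for 100 epochs
    c, s = 0, 0  # c: number of views seen so far, s: running mean loss
    pBar = tqdm(dl_tr)
    for data in pBar:
        # data has shape [B, 2, C, H, W]; flatten it to [2B, C, H, W] so that
        # rows 2k and 2k+1 are the two views of image k
        d = data.size()
        x = data.view(d[0]*2, d[2], d[3], d[4]).to(device)
        optimizer.zero_grad()
        p = model(x)
        loss = nt_xent(p)
        s = ((s*c)+(float(loss)*len(p)))/(c+len(p))
        c += len(p)
        pBar.set_description('Train: '+str(round(float(s),3)))
        loss.backward()
        optimizer.step()
    # Save a checkpoint every 10 epochs, once the epoch has finished
    if (i+1) % 10 == 0:
        torch.save(model.state_dict(),
                   os.path.join(PATH, 'cifar10-rn50-mlp-b256-t0.5-e'+str(i+1)+'.pt'))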
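
The progress-bar value `s` is a running mean of the loss, weighted by the number of views in each batch (`len(p)`). This keeps the epoch average correct even though the last batch of an epoch is smaller (50000 is not a multiple of 128). The helper below is hypothetical; it isolates the same update rule:

```python
def update_running_mean(s, c, batch_loss, n):
    """s is the mean over the c items seen so far; fold in a batch of n items
    whose mean loss is batch_loss. Returns the new (s, c)."""
    return (s * c + batch_loss * n) / (c + n), c + n

s, c = 0.0, 0
for batch_loss, n in [(2.0, 256), (1.0, 256), (0.5, 128)]:
    s, c = update_running_mean(s, c, batch_loss, n)
print(round(s, 4), c)  # 1.3 640
```

The result equals the plain weighted average (2.0·256 + 1.0·256 + 0.5·128) / 640 = 1.3.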

In [None]:
# Freeze all encoder parameters; only the new linear classifier (10 classes for CIFAR-10), created in the next cell, will be trained
for param in model.parameters():
    param.requires_grad = False

In [None]:
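# Replace the projection head with a linear classifier on the 2048-d representation;
# the parameters of this new layer are trainable by default
model.fc = nn.Linear(ch, len(ds_de.classes))
model.to(device)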
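
With every encoder weight frozen, this stage reduces to multinomial logistic regression on fixed features. The NumPy sketch below trains such a linear classifier by full-batch gradient descent on synthetic, well-separated 32-dimensional features (a small stand-in for the 2048-d ResNet-50 features). The data, the helper `train_linear_probe`, and its hyperparameters are illustrative assumptions; the notebook itself uses `nn.Linear`, `nn.CrossEntropyLoss`, and Adam.

```python
import numpy as np

def train_linear_probe(feats, labels, num_classes, lr=0.05, steps=300):
    """Softmax regression: logits = feats @ W + b, trained on mean cross-entropy."""
    n, d = feats.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(steps):
        logits = feats @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n                   # d(mean CE) / d(logits)
        W -= lr * feats.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b

rng = np.random.default_rng(0)
num_classes, dim = 10, 32
centers = rng.normal(scale=3.0, size=(num_classes, dim))  # one cluster per class
labels = rng.integers(0, num_classes, size=1000)
feats = centers[labels] + rng.normal(size=(1000, dim))    # stand-in "frozen features"

W, b = train_linear_probe(feats, labels, num_classes)
accuracy = ((feats @ W + b).argmax(axis=1) == labels).mean()
print(accuracy > 0.95)  # True
```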

In [None]:
# Optimize only the new classifier, with a different learning rate
optimizer = Adam(model.fc.parameters(), lr=0.003)
criterion = nn.CrossEntropyLoss()

In [None]:
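# Train the linear classifier on top of the frozen encoder for 5 epochs.
# model.eval() keeps the frozen encoder's BatchNorm running statistics fixed;
# gradients still flow to the new linear layer, the only part that is updated.
model.eval()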
for i in range(5):
    c, s = 0, 0
    pBar = tqdm(dl_de)
    for data in pBar:
        x, y = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        p = model(x)
        loss = criterion(p, y)
        s = ((s*c)+(float(loss)*len(p)))/(c+len(p))
        c += len(p)
        pBar.set_description('Train: '+str(round(float(s),3)))
        loss.backward()
        optimizer.step()

In [None]:
# Re-create the optimizer for the classifier with a lower learning rate
optimizer = Adam(model.fc.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

In [None]:
# Train the classifier for 5 more epochs with a learning rate of 0.0001
# (eval mode keeps the frozen encoder's BatchNorm statistics fixed)
model.eval()
for i in range(5):
    c, s = 0, 0
    pBar = tqdm(dl_de)
    for data in pBar:
        x, y = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        p = model(x)
        loss = criterion(p, y)
        s = ((s*c)+(float(loss)*len(p)))/(c+len(p))
        c += len(p)
        pBar.set_description('Train: '+str(round(float(s),3)))
        loss.backward()
        optimizer.step()

### Test the model

In [None]:
model.eval()
c, s = 0, 0
pBar = tqdm(dl_te)
with torch.no_grad():  # no gradients are needed for evaluation
    for data in pBar:
        x, y = data[0].to(device), data[1].to(device)
        p = model(x)
        loss = criterion(p, y)
        s = ((s*c)+(float(loss)*len(p)))/(c+len(p))
        c += len(p)
        pBar.set_description('Test: '+str(round(float(s),3)))

In [None]:
# Collect the logits and labels for the whole test set
model.eval()
y_pred, y_true = [], []
pBar = tqdm(dl_te)
with torch.no_grad():
    for data in pBar:
        x, y = data[0].to(device), data[1].to(device)
        p = model(x)
        y_pred.append(p.cpu().numpy())
        y_true.append(y.cpu().numpy())
y_pred = np.concatenate(y_pred)
y_true = np.concatenate(y_true)

In [None]:
# Top-1 accuracy on the CIFAR-10 test set
(y_true == y_pred.argmax(axis=1)).mean()