In [None]:
# Work from the project folder on Google Drive.
# NOTE(review): hard-coded Colab/Drive path; presumably Drive was mounted earlier — adjust when running elsewhere.
%cd /content/drive/MyDrive/self-supervised

/content/drive/MyDrive/self-supervised


In [None]:
# pytorch lightning is an easy to use  wrapper for pytorch.
%%capture
!pip install pytorch-lightning

In [None]:
import cv2
import pickle
import numpy as np
import pandas as pd
import torch
import random
import torch.nn as nn
import torchvision as tv
from PIL import Image
import pytorch_lightning as pl
from pathlib import Path
from torchvision import transforms
import torchvision.transforms as tt
from torch.utils.data.dataloader import DataLoader
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
%matplotlib inline

### Load the CIFAR-100 dataset

In [None]:
# Dataset normalization and the train/test transformations.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel (R, G, B) statistics used to standardise every image.
normalize = tt.Normalize(
    mean=[0.5074, 0.4867, 0.4411],
    std=[0.2011, 0.1987, 0.2025],
)

# Training pipeline: random horizontal flip and a reflect-padded random 32x32
# crop, then tensor conversion and normalisation.
train_transform = tt.Compose(
    [
        tt.RandomHorizontalFlip(0.5),
        tt.RandomCrop(32, padding=4, padding_mode="reflect"),
        tt.ToTensor(),
        normalize,
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
test_transform = tt.Compose([tt.ToTensor(), normalize])

In [None]:
# Fetch CIFAR-100 into ./data (re-runs just verify the existing files);
# augmentation is attached to the training split only.
CIFAR100_train = tv.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=train_transform
)
CIFAR100_test = tv.datasets.CIFAR100(
    root='./data', train=False, download=True, transform=test_transform
)


Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Sanity check: the first (augmented) training sample should be a 3x32x32
# tensor, and the splits should hold 50,000 / 10,000 images.
image, label = CIFAR100_train[0]
print(image.shape)
for caption, split in (("train data  = ", CIFAR100_train), ("test data  = ", CIFAR100_test)):
    print(caption, len(split))


torch.Size([3, 32, 32])
train data  =  50000
test data  =  10000


In [None]:
# Display the normalised tensor of the sample loaded above; values fall outside
# [0, 1] because `normalize` was applied.
image

tensor([[[ 2.4495,  2.4495,  2.4495,  ...,  0.5385,  0.9090,  0.5385],
         [ 2.4495,  2.4495,  2.4495,  ...,  0.7335,  1.2210,  1.6110],
         [ 2.4495,  2.4495,  2.4495,  ...,  1.1625,  1.2405,  1.7670],
         ...,
         [ 1.0065,  1.0065,  1.2015,  ..., -1.2361, -0.8461, -0.5926],
         [ 0.4605,  0.1875,  0.7530,  ..., -1.0411, -0.5731, -0.2416],
         [ 0.5190,  0.4020,  0.6945,  ..., -1.0996, -0.4756, -0.1050]],

        [[ 2.5833,  2.5833,  2.5833,  ...,  0.7873,  0.8268,  0.2150],
         [ 2.5833,  2.5833,  2.5833,  ...,  1.1228,  1.5767,  1.8136],
         [ 2.5833,  2.5833,  2.5833,  ...,  1.4781,  1.6557,  2.1294],
         ...,
         [ 1.7938,  1.7938,  1.9715,  ..., -1.5613, -1.1468, -0.8508],
         [ 1.2610,  1.0044,  1.6162,  ..., -1.3245, -0.8903, -0.4758],
         [ 1.2807,  1.2018,  1.5175,  ..., -1.3442, -0.7324, -0.2982]],

        [[ 2.7600,  2.7600,  2.7600,  ...,  0.1069,  0.3199, -0.2611],
         [ 2.7600,  2.7600,  2.7600,  ...,  0

In [None]:
 # Class names exposed by the dataset object; CIFAR-100 should report 100 of them.
 class_labels = CIFAR100_train.classes
 print("Total class labels = ", len(CIFAR100_train.classes))

Total class labels =  100


In [None]:
# Transform callables used to build the self-supervised (pretext) labels.
class RotationalTransform:
  """Callable that rotates an image by a fixed angle.

  Attributes
  -----------
  angle : int
            rotation angle in degrees, passed straight to
            torchvision.transforms.functional.rotate (counter-clockwise)
  """
  def __init__(self, angle):
    self.angle = angle

  def __call__(self, x):
    rotated = TF.rotate(x, self.angle)
    return rotated

class VerticalFlip:
  """Callable that mirrors an image top-to-bottom (torchvision `vflip`)."""

  def __call__(self, x):
    flipped = TF.vflip(x)
    return flipped

class HorizontalFlip:
  """Callable that mirrors an image left-to-right (torchvision `hflip`)."""

  def __call__(self, x):
    flipped = TF.hflip(x)
    return flipped

In [None]:
# len(CIFAR100_test)

In [None]:
from PIL import Image
from matplotlib import cm


In [None]:
class SelfSupervisedDataset(object):
  """Relabels a wrapped dataset by a randomly chosen image transformation.

  Each access draws one of the candidate transformations uniformly at random,
  applies it to the wrapped image and returns its position in
  ``class_transforms`` as the (pretext) label; the original label is dropped.

  Attributes
  -----------
  path : dataset
            the wrapped, indexable dataset (despite the name, a dataset object
            such as a torchvision CIFAR100 instance, not a filesystem path)
  class_transforms : list
            candidate transformations; list position == pretext label
  to_tensor : transform
            a ToTensor pipeline (not used by __getitem__)
  classes : int
            number of pretext classes, i.e. len(class_transforms)
  """
  def __init__(self, image_path=CIFAR100_train):
    self.path = image_path
    self.class_transforms = [
        RotationalTransform(0),
        RotationalTransform(90),
        RotationalTransform(180),
        RotationalTransform(270),
        HorizontalFlip(),
        VerticalFlip(),
    ]
    self.to_tensor = tt.Compose([tt.ToTensor()])
    self.classes = len(self.class_transforms)

  def __getitem__(self, idx):
    # NOTE: the label is re-drawn on every access, so the same index can map to
    # a different (image, label) pair from one epoch to the next.
    img, _original_label = self.path[idx]
    pretext_label = random.randrange(self.classes)
    return self.class_transforms[pretext_label](img), pretext_label

  def __len__(self):
    return len(self.path)

In [None]:
 # Return cached-but-unused GPU memory held by PyTorch's allocator before training
 # (memory still referenced by live tensors is not freed).
 torch.cuda.empty_cache()

In [None]:
# ResNet-18 wrapped in a PyTorch Lightning module (used for pretext and downstream training).
class self_supervised_model(pl.LightningModule):
  """PyTorch Lightning wrapper around a randomly initialised ResNet-18.

  Parameters
  -----------
  hparams : dict, optional
                hyperparameters; "lr" defaults to 0.001 when absent. The dict is
                copied (the caller's object is never modified) and stored with
                `save_hyperparameters`, so `load_from_checkpoint` can rebuild it.
  num_classes : int
                size of the output layer (default 6, one per pretext transform)
  batch_size : int
                batch size of the module's own (pretext) dataloaders

  Attributes
  -----------
  resnet : CNN model
               ResNet-18 backbone used for feature extraction
  resnet.fc : sequential container
                  container holding the final Linear(512, num_classes) layer
  loss_fn : function
                cross-entropy loss for the classifier
  training_loss, training_acc, test_loss, test_acc : list of float
                per-step metric history, stored as plain Python floats
  """
  def __init__(self, hparams=None, num_classes=6, batch_size=64):
    super().__init__()
    self.resnet = tv.models.resnet18(pretrained=False)
    self.resnet.fc = nn.Sequential(nn.Linear(512, num_classes))
    self.batch_size = batch_size
    self.loss_fn = nn.CrossEntropyLoss()
    # Copy so the `None` default works and the caller's dict is never mutated.
    hparams = dict(hparams or {})
    hparams.setdefault("lr", 0.001)
    # Assigning `self.hparams = ...` relies on a setter Lightning deprecated (1.1) and
    # later removed; save_hyperparameters exposes the dict as `self.hparams` and
    # records it in checkpoints.
    self.save_hyperparameters(hparams)
    self.training_loss = []
    self.test_loss = []
    self.training_acc = []
    self.test_acc = []

  def forward(self, x):
    return self.resnet(x)

  @staticmethod
  def _accuracy(predictions, targets):
    """Fraction of samples whose highest-scoring class equals the target."""
    return (predictions.argmax(dim=1) == targets).float().mean()

  def training_step(self, batch, batch_idx):
    inputs, targets = batch
    predictions = self(inputs)
    loss = self.loss_fn(predictions, targets)
    train_acc = self._accuracy(predictions, targets)
    # Store floats: appending the tensors themselves would keep every step's GPU
    # tensor (and, for the loss, its autograd history) alive for the whole run.
    self.training_loss.append(loss.item())
    self.training_acc.append(train_acc.item())
    return {'loss': loss, 'train_acc': train_acc}

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])

  def prepare_data(self):
    # Download only; Lightning's docs say state should not be assigned here
    # (it runs on a single process). Datasets are built in `setup`.
    tv.datasets.CIFAR100(root='./data', train=True, download=True)
    tv.datasets.CIFAR100(root='./data', train=False, download=True)

  def setup(self, stage=None):
    """Build the pretext datasets used by this module's own dataloaders.

    The pretext training images are NOT randomly pre-flipped horizontally:
    the label set contains both "identity" and "horizontal flip", and both
    "rotate 180" and "vertical flip" (= rotate 180 of a horizontally flipped
    image), so a random pre-flip would make those label pairs indistinguishable.
    """
    pretext_transform = tt.Compose([
        tt.RandomCrop(32, padding=4, padding_mode="reflect"),
        tt.ToTensor(),
        normalize,
    ])
    pretext_train = tv.datasets.CIFAR100(root='./data', train=True, download=False,
                                         transform=pretext_transform)
    self.training_dataset = SelfSupervisedDataset(pretext_train)
    self.val_dataset = SelfSupervisedDataset(CIFAR100_test)

  def train_dataloader(self):
    return torch.utils.data.DataLoader(self.training_dataset, batch_size=self.batch_size, num_workers=2, shuffle=True)

  def val_dataloader(self):
    return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=2)

  def validation_step(self, batch, batch_idx):
    inputs, targets = batch
    predictions = self(inputs)
    val_loss = self.loss_fn(predictions, targets)
    acc = self._accuracy(predictions, targets)
    self.test_acc.append(acc.item())
    self.test_loss.append(val_loss.item())
    # Tensors are returned so validation_epoch_end can stack them.
    return {'val_loss': val_loss, 'val_acc': acc}

  def validation_epoch_end(self, outputs):
    # Unweighted mean of per-batch values: a smaller final batch counts as much
    # as a full one.
    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    avg_acc = torch.stack([x['val_acc'].float() for x in outputs]).mean()
    self.log('val_loss', avg_loss, on_epoch=True, prog_bar=True)
    self.log('val_acc', avg_acc, on_epoch=True, prog_bar=True)




In [None]:
# model =  self_supervised_model({'lr': 0.001})

### Self-supervised pre-training on rotated/flipped images (pretext task)

In [None]:
# Pretext training: the network learns to predict which SelfSupervisedDataset
# transformation was applied to each image.
# Seed Python/NumPy/torch RNGs so random pretext labels, augmentation and weight
# init are repeatable (cuDNN kernels may still be non-deterministic).
pl.seed_everything(42)
model = self_supervised_model({'lr': 0.001})

trainer = pl.Trainer(max_epochs=60, gpus=1)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | resnet  | ResNet           | 11.2 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.718    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [None]:
# Save a Lightning checkpoint of the pretext-trained model; it is reloaded later
# with `load_from_checkpoint`. NOTE: Lightning checkpoint format despite the ".pth" suffix.
trainer.save_checkpoint("selfsupervised_model.pth")

In [None]:
# `load_from_checkpoint` is a classmethod: calling it on the class builds a fresh
# instance from the saved checkpoint.
model = self_supervised_model.load_from_checkpoint("selfsupervised_model.pth")
# verify the model
model

self_supervised_model(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tr

### Fine-tune the pre-trained model for CIFAR-100 classification (downstream task)

In [None]:
## Define Dataloaders
# Supervised loaders over the original 100-class CIFAR-100 labels; only the
# training loader reshuffles each epoch.
train_loader = DataLoader(
    CIFAR100_train,
    batch_size=16,
    num_workers=2,
    shuffle=True,
)
test_loader = DataLoader(
    CIFAR100_test,
    batch_size=16,
    num_workers=2,
    shuffle=False,
)

In [None]:
 # Reload the pretext-trained weights (classmethod call -> fresh instance).
 model = self_supervised_model.load_from_checkpoint("selfsupervised_model.pth")

 # Replace the 6-way pretext head with a new, randomly initialised 100-way
 # CIFAR-100 classifier; the backbone keeps its pretext weights.
 model.resnet.fc = nn.Linear(512, 100)

In [None]:
# Downstream fine-tuning: every layer stays trainable (nothing is frozen).
trainer = pl.Trainer(gpus=1, max_epochs=5)
trainer.fit(
    model,
    train_dataloader=train_loader,
    val_dataloaders=test_loader,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | resnet  | ResNet           | 11.2 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.911    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [None]:
def weights_init_He(m):
  """He (Kaiming) initialisation for Conv2d layers; meant for ``module.apply``.

  Conv weights are drawn from N(0, 2 / fan_in), with
  fan_in = in_channels * kernel_h * kernel_w, which is the He et al. (2015)
  scheme for ReLU networks. Conv biases, if present, are set to zero.
  Every other module type is left untouched.

  Parameters
  ----------
  m : nn.Module
      the module visited by ``nn.Module.apply``.
  """
  if isinstance(m, nn.Conv2d):
    # Zero-mean normal, not uniform(0, std): all-positive weights would give every
    # activation a systematic positive shift.
    nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
    if m.bias is not None:
      nn.init.zeros_(m.bias)

# Baseline: same architecture with a 100-way head, no pretext training; its conv
# weights are re-initialised by weights_init_He.
original_model = self_supervised_model(hparams={'lr': 0.001}, num_classes=100)
original_model.apply(weights_init_He)

self_supervised_model(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tr

In [None]:
## Let's train the baseline model without any pretext training.
# NOTE: 10 epochs here vs. 5 fine-tuning epochs for the pre-trained model.
trainer = pl.Trainer(gpus=1, max_epochs=10)
trainer.fit(
    original_model,
    train_dataloader=train_loader,
    val_dataloaders=test_loader,
)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | resnet  | ResNet           | 11.2 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.911    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

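As we can see, when the model was first pre-trained on the self-supervised pretext task and then fine-tuned on CIFAR-100, it reached 43.2% test accuracy after 5 fine-tuning epochs. The same ResNet-18 trained from scratch with He initialization reached only 31.4% after 10 epochs. Note that the comparison is not epoch-matched: the pre-trained model also received 60 epochs of pretext training beforehand.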