#### Implementation of SimCLR framework to carry out emotion classification on German Shepherds
- Date : February 17th 2024
- Author : Aarya Bhave
- Project : Dog_Emotion_Classification
  
This code implements the SimCLR framework to carry out self-supervised contrastive learning to detect emotions in German Shepherds.  
Currently Limited to German Shepherds only.  
PyTorch version '2.1.1'.

In [1]:
import numpy as np
import pandas as pd
import shutil, time, os, tqdm, requests, random, copy
import PIL
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models

import matplotlib.pyplot as plt 
%matplotlib inline

from sklearn.manifold import TSNE

In [2]:
def set_seed(seed = 16):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy's
            global generator and PyTorch.
    """
    # Python's own generator is imported by the notebook, so it has to be
    # seeded too for runs to be repeatable.
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the generators on all devices.
    torch.manual_seed(seed)

##### Set up the DataSet class.
- 264 unlabeled images of German Shepherds.
- Processor threading based augmentation techniques for low latency.

In [3]:
from imutils import paths

class GSDDataset(torch.utils.data.Dataset):
    """Unlabeled image dataset built from every image file under a directory.

    Args:
        transforms: Optional callable applied to each loaded PIL image before
            it is returned.
        root_dir: Directory that is searched for images. Defaults to the
            German Shepherd YOLO crops folder used by this project.
    """

    def __init__(self, transforms = None,
                 root_dir = 'data/processed/YOLOCrops/German_Shepherd/dog'):
        self.root_dir = root_dir
        # Sort so that index -> file mapping does not depend on the order in
        # which the file system happens to list the directory.
        self.image_paths = sorted(paths.list_images(self.root_dir))
        self.transforms = transforms

    def __getitem__(self, index):
        """Load image ``index``, force it to 3-channel RGB and apply transforms."""
        # The context manager closes the file handle; convert() returns an
        # independent in-memory image, so it stays valid afterwards.
        # Forcing RGB keeps grayscale / RGBA / palette files from producing
        # tensors with a different channel count than the rest of the batch.
        with Image.open(self.image_paths[index]) as img:
            sample = img.convert('RGB')
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample

    def __len__(self):
        """Number of image files found under ``root_dir``."""
        return len(self.image_paths)
    

##### Set up the Data Transformations.
- MUST BE TINKERED WITH.
- Data augmentations are treated as hyper-parameters.
- Images will be resized to 224, 224.
- 18FEB2024 - Random Resize and Crop, Random Horizontal Flip, Random Color Jitters.

In [4]:
class Augment:
    """Random SimCLR-style augmentation applied to a single image.

    The pipeline is: random resized crop to 224x224 (keeping 75-100% of the
    area), horizontal flip with probability 0.5, and colour jitter on
    brightness, contrast, saturation and hue.
    """

    def __init__(self):
        # Build the pipeline once; the transforms draw fresh random parameters
        # on every call, so there is no need to re-create them per sample.
        self.pipeline = transforms.Compose([
            transforms.RandomResizedCrop((224, 224), scale = (0.75, 1)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=(0.80,1.20),
                                   contrast=(0.75, 1.25),
                                   saturation=(0.75,1.25),
                                   hue=(-0.1,0.1)),
        ])

    def __call__(self, sample):
        """Return a randomly augmented version of ``sample``."""
        return self.pipeline(sample)
    
# Per-sample pipeline handed to the dataset: random augmentations first,
# then conversion of the resulting image to a tensor.
augmentations = transforms.Compose([Augment(), transforms.ToTensor()])

##### Set up hyper-parameters.
- Epochs ~600.
- Batch Size ~256.
- Temperature ~0.1

In [5]:
# Number of images per mini-batch; passed to the DataLoader below.
BATCH_SIZE = 256
# Planned number of training epochs.
# NOTE(review): not referenced in the visible part of the notebook.
EPOCHS = 600

##### Set up the DataLoaders
- Issue: `shuffle=True` caused the `len()` argument to break down. Fixed.

In [6]:
# Dataset of augmented image tensors, served in shuffled mini-batches by
# 8 worker processes.
# NOTE(review): the notes above mention 264 images; with a batch size of 256
# and no drop_last, each epoch ends with a small trailing batch (8 samples).
# Confirm the training loop / loss handles it, or pass drop_last=True.
dataset = GSDDataset(transforms=augmentations)
dataloader = DataLoader(dataset = dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

In [7]:
# Sanity check: pull one batch from the loader and inspect its shape.
dataiter = iter(dataloader)
data = next(dataiter)
data.shape

torch.Size([256, 3, 224, 224])

##### Model
- Encoder -> Projection Head -> NT Xent Loss.
- Encoder ~ResNet50 18FEB2024
- Projection Head ~MultiLayerPerceptron with one hidden layer 18FEB2024
- NT Xent Loss Temperature ~0.2 18FEB2024

In [8]:
class Identity(nn.Module):
    """Pass-through module whose ``forward`` returns its input untouched.

    It has no parameters and can stand in for a layer that should be
    disabled without changing the surrounding network's structure.
    """

    def forward(self, x):
        """Return ``x`` itself (same object, no copy)."""
        return x

class LinearLayer(nn.Module):
    """Fully connected layer with an optional 1-D batch normalisation after it.

    Args:
        in_features: Size of each input vector.
        out_features: Size of each output vector.
        use_bias: Whether the linear map should carry a bias term. Ignored
            (no bias is created) when ``use_bn`` is true.
        use_bn: Whether to append ``nn.BatchNorm1d(out_features)``.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, in_features, out_features, use_bias = True, use_bn = False, **kwargs):
        super().__init__(**kwargs)
        self.in_features, self.out_features = in_features, out_features
        self.use_bias, self.use_bn = use_bias, use_bn

        # A bias is only created when requested and no batch norm follows.
        wants_bias = self.use_bias and not self.use_bn
        self.linear = nn.Linear(self.in_features, self.out_features, bias=wants_bias)

        if self.use_bn:
            self.bn = nn.BatchNorm1d(self.out_features)

    def forward(self, x):
        """Apply the linear map, then batch norm if it was enabled."""
        projected = self.linear(x)
        return self.bn(projected) if self.use_bn else projected
    
class ProjectionHead(nn.Module):
    """Projection head mapping encoder features to the contrastive space.

    Args:
        in_features: Size of the encoder output vectors.
        hidden_features: Width of the hidden layer (only used by the
            ``'nonlinear'`` head).
        out_features: Size of the projected vectors.
        head_type: ``'linear'`` for a single linear layer followed by batch
            norm, or ``'nonlinear'`` for linear+BN -> ReLU -> linear+BN.
        **kwargs: Forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: If ``head_type`` is neither ``'linear'`` nor ``'nonlinear'``.
    """

    def __init__(self,
                 in_features,
                 hidden_features,
                 out_features,
                 head_type = 'nonlinear',
                 **kwargs):
        # Fail at construction time: previously an unknown head_type silently
        # left ``self.layers`` undefined and only blew up in forward().
        if head_type not in ('linear', 'nonlinear'):
            raise ValueError(
                f"head_type must be 'linear' or 'nonlinear', got {head_type!r}"
            )

        super(ProjectionHead, self).__init__(**kwargs)
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.head_type = head_type

        if self.head_type == 'linear':
            self.layers = LinearLayer(self.in_features, self.out_features, False, True)
        else:
            # Both layers enable batch norm, so LinearLayer creates neither
            # bias regardless of the use_bias flag passed here.
            self.layers = nn.Sequential(LinearLayer(self.in_features, self.hidden_features, True, True),
                                        nn.ReLU(),
                                        LinearLayer(self.hidden_features, self.out_features, False, True))

    def forward(self, x):
        """Project a batch of feature vectors through the configured head."""
        return self.layers(x)
    
class PreModel(nn.Module):
    """SimCLR pre-training model: ResNet-50 encoder plus a projection head.

    Args:
        base_model: Name of the backbone. It is stored on the instance but the
            encoder built here is always a ResNet-50.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

        # ImageNet-pretrained encoder. IMAGENET1K_V1 is the checkpoint that the
        # deprecated ``pretrained=True`` flag used to select.
        self.pretrained = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Replace the stem: a freshly initialised 3x3 stride-1 convolution
        # (the pretrained conv1 weights are discarded) and no max-pooling.
        # NOTE(review): this keeps feature maps large for 224x224 inputs —
        # confirm the memory cost is intended.
        self.pretrained.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
        self.pretrained.maxpool = Identity()

        # Drop the classification layer so the encoder emits pooled features.
        self.pretrained.fc = Identity()

        # Fine-tune the whole encoder.
        for p in self.pretrained.parameters():
            p.requires_grad = True

        self.projector = ProjectionHead(2048, 2048, 128)

    def forward(self, x):
        """Encode a batch of images and return their projections."""
        out = self.pretrained(x)
        # Flatten everything except the batch dimension. torch.squeeze() would
        # also remove the batch dimension for a single-sample batch.
        xp = self.projector(torch.flatten(out, start_dim=1))
        return xp

In [9]:
# Build the SimCLR model and move it to the first CUDA device (a GPU is required).
model = PreModel('resnet50').to('cuda:0')

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/aarya/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:03<00:00, 30.4MB/s]


##### Loss Function 
- SimCLR Loss function picked up from Spijkervet/SimCLR on GitHub.

In [11]:
class SimCLR_Loss(nn.Module):
    """NT-Xent (normalised temperature-scaled cross entropy) loss for SimCLR.

    Args:
        batch_size: Expected number of samples per view. A mask for this size
            is pre-computed; batches of another size (e.g. the last, smaller
            batch of an epoch) get a mask built on the fly.
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, batch_size, temperature):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature

        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        """Return a (2B, 2B) boolean mask that is True only for negative pairs.

        Self-similarities (the diagonal) and the two positive-pair diagonals
        at offsets +B and -B are set to False.
        """
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(batch_size)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def forward(self, z_i, z_j):
        """Compute the mean NT-Xent loss over both views.

        Args:
            z_i: Projections of the first view, one row per sample.
            z_j: Projections of the second view; row ``k`` is the positive
                partner of row ``k`` in ``z_i``.

        Returns:
            Scalar tensor: summed cross entropy divided by ``2 * batch``.

        Raises:
            ValueError: If the two views contain a different number of rows.
        """
        if z_i.shape[0] != z_j.shape[0]:
            raise ValueError(
                f"z_i and z_j must have the same number of rows, "
                f"got {z_i.shape[0]} and {z_j.shape[0]}"
            )

        # Derive the batch size from the input so that a batch smaller than
        # the configured one does not break the reshapes below.
        batch_size = z_i.shape[0]
        N = 2 * batch_size
        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size)

        z = torch.cat((z_i, z_j), dim=0)

        # Pairwise cosine similarity between all 2B projections.
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        # Positive pairs sit on the diagonals offset by +/- batch_size.
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        # Every row has exactly N - 2 negatives (all but itself and its positive).
        negative_samples = sim[mask.to(sim.device)].reshape(N, N - 2)

        # The positive logit is placed in column 0, so the target class is 0.
        labels = torch.zeros(N, dtype=torch.long, device=sim.device)

        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N

        return loss

##### Optimizer
- LARS
- Learning Rate - 0.2
- Linear warmup for 10 epochs, and then a cosine decay schedule.
- Picked up from Spijkervet/SimCLR on GitHub.

In [12]:
from torch.optim.optimizer import Optimizer, required
import re

# Default coefficient used when computing the layer-wise trust ratio.
EETA_DEFAULT = 0.001


class LARS(Optimizer):
    """
    Layer-wise Adaptive Rate Scaling for large batch training.
    Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
    I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)

    Only the "classic momentum" variant is implemented. Parameter names are
    not available inside the optimizer, so the exclusion lists are stored but
    currently have no effect: weight decay and layer adaptation are applied
    to every parameter.
    """

    def __init__(
        self,
        params,
        lr=required,
        momentum=0.9,
        use_nesterov=False,
        weight_decay=0.0,
        exclude_from_weight_decay=None,
        exclude_from_layer_adaptation=None,
        classic_momentum=True,
        eeta=EETA_DEFAULT,
    ):
        """Constructs a LARSOptimizer.
        Args:
        params: Iterable of parameters (or parameter-group dicts) to optimize.
        lr: A `float` for learning rate.
        momentum: A `float` for momentum.
        use_nesterov: A 'Boolean' for whether to use nesterov momentum.
        weight_decay: A `float` for weight decay.
        exclude_from_weight_decay: A list of `string` for variable screening, if
            any of the string appears in a variable's name, the variable will be
            excluded for computing weight decay. For example, one could specify
            the list like ['batch_normalization', 'bias'] to exclude BN and bias
            from weight decay. (Stored only; not applied by `step` yet.)
        exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
            for layer adaptation. If it is None, it will be defaulted the same as
            exclude_from_weight_decay. (Stored only; not applied by `step` yet.)
        classic_momentum: A `boolean` for whether to use classic (or popular)
            momentum. The learning rate is applied during momentum update in
            classic momentum, but after momentum for popular momentum.
        eeta: A `float` for scaling of learning rate when computing trust ratio.
        """

        self.epoch = 0
        defaults = dict(
            lr=lr,
            momentum=momentum,
            use_nesterov=use_nesterov,
            weight_decay=weight_decay,
            exclude_from_weight_decay=exclude_from_weight_decay,
            exclude_from_layer_adaptation=exclude_from_layer_adaptation,
            classic_momentum=classic_momentum,
            eeta=eeta,
        )

        super(LARS, self).__init__(params, defaults)
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.use_nesterov = use_nesterov
        self.classic_momentum = classic_momentum
        self.eeta = eeta
        self.exclude_from_weight_decay = exclude_from_weight_decay
        # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
        # arg is None.
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def step(self, epoch=None, closure=None):
        """Perform a single optimization step.

        Args:
            epoch: Optional epoch index. When omitted, the optimizer's internal
                step counter (``self.epoch``) is advanced instead.
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            NotImplementedError: If a parameter group has
                ``classic_momentum`` disabled.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if epoch is None:
            self.epoch += 1

        for group in self.param_groups:
            # Read every hyper-parameter from the group so that per-group
            # overrides and LR schedulers are honoured.
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            eeta = group["eeta"]
            lr = group["lr"]
            use_nesterov = group["use_nesterov"]
            classic_momentum = group["classic_momentum"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                if not classic_momentum:
                    raise NotImplementedError

                param = p.data
                grad = p.grad.data

                param_state = self.state[p]

                # TODO: get param names
                # if self._use_weight_decay(param_name):
                # Out-of-place so the stored p.grad is left untouched.
                if weight_decay != 0:
                    grad = grad.add(param, alpha=weight_decay)

                # TODO: get param names
                # if self._do_layer_adaptation(param_name):
                w_norm = torch.norm(param)
                g_norm = torch.norm(grad)

                # Trust ratio eeta * ||w|| / ||g||, falling back to 1.0 when
                # either norm is zero. ones_like keeps the fallback on the
                # same device/dtype as the norms (works on CPU and GPU).
                trust_ratio = torch.where(
                    w_norm.gt(0) & g_norm.gt(0),
                    eeta * w_norm / g_norm,
                    torch.ones_like(w_norm),
                ).item()

                scaled_lr = lr * trust_ratio
                next_v = param_state.get("momentum_buffer")
                if next_v is None:
                    next_v = param_state["momentum_buffer"] = torch.zeros_like(
                        p.data
                    )

                # v <- momentum * v + scaled_lr * grad
                next_v.mul_(momentum).add_(grad, alpha=scaled_lr)
                if use_nesterov:
                    update = (momentum * next_v) + (scaled_lr * grad)
                else:
                    update = next_v

                p.data.sub_(update)

        return loss

    def _use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _do_layer_adaptation(self, param_name):
        """Whether to do layer-wise learning rate adaptation for `param_name`."""
        if self.exclude_from_layer_adaptation:
            for r in self.exclude_from_layer_adaptation:
                if re.search(r, param_name) is not None:
                    return False
        return True

In [14]:
# LARS over all trainable parameters.
# NOTE(review): exclude_from_weight_decay is stored by LARS, but LARS.step()
# never looks at parameter names (see its TODOs), so weight decay is applied
# to every parameter, including batch-norm weights and biases.
optimizer = LARS(
    [params for params in model.parameters() if params.requires_grad],
    lr=0.2,
    weight_decay=1e-6,
    exclude_from_weight_decay=["batch_normalization", "bias"],
)

# "decay the learning rate with the cosine decay schedule without restarts"
# Scheduler for linear warm-up: scales the base LR by (epoch + 1) / 10.
# NOTE(review): the factor keeps growing past 1.0 after 10 steps, so the
# training loop presumably stops stepping this scheduler after warm-up — confirm.
warmupscheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch : (epoch+1)/10.0, verbose = True)

# Scheduler for cosine decay down to eta_min=0.05.
# NOTE(review): CosineAnnealingWarmRestarts restarts every 500 steps, which
# conflicts with the "without restarts" quote above if it is stepped more than
# 500 times (EPOCHS is 600) — confirm.
mainscheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 500, eta_min=0.05, last_epoch=-1, verbose = True)

# Loss function.
# NOTE(review): batch_size=128 differs from BATCH_SIZE=256 used by the
# DataLoader, and temperature=0.5 differs from the ~0.1 / ~0.2 values quoted
# in the notes — confirm which values are intended.
criterion = SimCLR_Loss(batch_size = 128, temperature = 0.5)

Adjusting learning rate of group 0 to 2.0000e-02.
Epoch 00000: adjusting learning rate of group 0 to 2.0000e-01.


In [15]:
def save_model(model, optimizer, scheduler, current_epoch, name,
               save_dir='/content/saved_models/'):
    """Write a training checkpoint to disk.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        current_epoch: Value substituted into ``name`` via ``str.format``.
        name: File-name template, e.g. ``'model_{}.pt'``.
        save_dir: Directory the checkpoint is written to. It is created if it
            does not exist. Defaults to the original hard-coded location.
    """
    # torch.save does not create missing directories, so make sure it exists.
    os.makedirs(save_dir, exist_ok=True)
    out = os.path.join(save_dir, name.format(current_epoch))

    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict':scheduler.state_dict()}, out)
