# Self-Supervised Learning <a id="top"></a>

---
## Table of Contents

* [Self-Supervised Learning Overview](#ssl_overview)
    * [Core Concepts](#core_concepts)
    * [How It Works](#how_it_works)
* [Self-Supervised Learning Tutorial](#ssl_tutorial)
    * [Imports](#imports)
    * [Dataset Preparation](#dataset_prep)
    * [Model Architecture](#model_architecture)
    * [Projection Head](#projection_head)
    * [Simple Contrastive Learning of Representations](#simclr)
    * [Contrastive Loss Function](#constrastive_loss)
    * [Init the Model](#init_model)
    * [Training Loop](#training_loop)
* [Downstream Task (Image Classification)](#img_classification)
    * [Setup](#setup)
    * [Feature Extraction](#feature_extraction)
    * [Init the Classifier](#init_classifier)
    * [Train the Classifier](#train_classifier)
    * [Evaluate the Classifier](#eval_classifier) 

# Self-Supervised Learning Overview <a class="anchor" id="ssl_overview"></a>

Self-supervised learning (SSL) applies unsupervised learning to tasks that conventionally require supervised learning. SSL has gained considerable interest in recent years for its ability to learn from unlabeled data, reduce annotation costs, and produce transferable representations.

Instead of relying on labeled datasets to understand semantic meanings, self-supervised models generate implicit labels from unstructured data. This enables the model to extract meaningful features from the data, allowing it to learn useful representations even without explicit labels.

SSL is particularly useful in fields like computer vision and natural language processing (NLP), where obtaining large amounts of labeled data can be challenging (e.g., anomaly detection).

A core technique in self-supervised learning is contrastive learning, which maximizes the similarity between representations of similar data points and minimizes the similarity between representations of dissimilar ones. Imagine showing your model two images: one of a cat and another of a dog. Contrastive learning encourages the model to create representations where the cat image's representation is closer to another cat image's representation than it is to the dog image's representation.
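A minimal sketch of this idea, using hand-picked 3-dimensional vectors as stand-ins for learned representations. The vectors and variable names are illustrative assumptions, not outputs of a real model. Cosine similarity is one common way to measure how close two representations are:

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hypothetical embeddings: two cat images and one dog image
cat_1 = [0.9, 0.1, 0.2]
cat_2 = [0.8, 0.2, 0.1]
dog_1 = [0.1, 0.9, 0.3]

sim_cat_cat = cosine_similarity(cat_1, cat_2)
sim_cat_dog = cosine_similarity(cat_1, dog_1)
print(f"cat vs. cat: {sim_cat_cat:.3f}")
print(f"cat vs. dog: {sim_cat_dog:.3f}")
```

A contrastive objective pushes a model toward embeddings where the first score is higher than the second.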

## How It Works <a class="anchor" id="how_it_works"></a>

- In supervised learning, ground truth labels are directly provided by human experts.
- In self-supervised learning, tasks are designed such that “ground truth” can be inferred from unlabeled data.
- SSL tasks fall into two categories:
  - Pretext Tasks: Train AI systems to learn meaningful representations of unstructured data. These learned representations can be subsequently used in downstream tasks.
  - Downstream Tasks: Reuse pre-trained models on new tasks, a technique known as "transfer learning".
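To make "ground truth inferred from unlabeled data" concrete, here is a sketch of one well-known pretext task, rotation prediction. It is a different pretext task from the contrastive one used in this tutorial, and the helper name `make_rotation_batch` is an assumption made for illustration:

```python
import torch

def make_rotation_batch(images):
    """Turn an unlabeled batch (N, C, H, W) into a labeled one for rotation prediction.

    Each image is rotated by 0, 90, 180, and 270 degrees; the label is the
    number of quarter turns (0-3), so no human annotation is needed.
    """
    rotated, labels = [], []
    for k in range(4):
        rotated.append(torch.rot90(images, k=k, dims=(2, 3)))
        labels.append(torch.full((images.size(0),), k, dtype=torch.long))
    return torch.cat(rotated, dim=0), torch.cat(labels, dim=0)

unlabeled = torch.rand(8, 3, 32, 32)  # stand-in for a batch of unlabeled images
inputs, pseudo_labels = make_rotation_batch(unlabeled)
print(inputs.shape, pseudo_labels.shape)
```

A model trained to predict `pseudo_labels` from `inputs` has to pick up on object shape and orientation, which is the kind of representation a downstream task can reuse.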

# Self-Supervised Learning Tutorial <a class="anchor" id="ssl_tutorial"></a>

## Imports <a class="anchor" id="imports"></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

## Dataset Preparation <a class="anchor" id="dataset_prep"></a>

For this tutorial, we'll use the CIFAR-10 dataset. You can download and load it using torchvision.

In [2]:
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

Files already downloaded and verified


## Model Architecture <a class="anchor" id="model_architecture"></a>

Define a simple convolutional neural network (CNN) as our base encoder.

In [3]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )

    def forward(self, x):
        return self.encoder(x)
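The projection head later in the tutorial takes `256 * 4 * 4` inputs. A short shape trace shows where that number comes from. The snippet rebuilds the same layer pattern as the encoder above so that it runs on its own, and the dummy batch is an assumption standing in for CIFAR-10 images:

```python
import torch
import torch.nn as nn

# Same layer pattern as the encoder above, rebuilt here so the snippet is self-contained
layers = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)

x = torch.zeros(2, 3, 32, 32)  # dummy batch of two CIFAR-10-sized images
for layer in layers:
    x = layer(x)
    if isinstance(layer, nn.MaxPool2d):
        print(tuple(x.shape))  # spatial size halves at each pool: 16 -> 8 -> 4

flat_dim = x.flatten(start_dim=1).size(1)
print(flat_dim)  # 256 * 4 * 4 = 4096
```

The padded 3x3 convolutions keep the spatial size unchanged, so only the three pooling layers shrink the 32x32 input down to 4x4.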

## Projection Head <a class="anchor" id="projection_head"></a>

Add a projection head to project the encoded features into a lower-dimensional space.

In [4]:
class ProjectionHead(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ProjectionHead, self).__init__()
        self.projection_head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.projection_head(x)

## Simple Contrastive Learning of Representations <a class="anchor" id="simclr"></a>

Combine the encoder and projection head into the SimCLR model.

In [5]:
class SimCLR(nn.Module):
    def __init__(self, encoder, projection_head):
        super(SimCLR, self).__init__()
        self.encoder = encoder
        self.projection_head = projection_head

    def forward(self, x):
        features = self.encoder(x)
        features = features.view(features.size(0), -1)  # Flatten the features
        projections = self.projection_head(features)
        return features, projections

## Contrastive Loss Function <a class="anchor" id="constrastive_loss"></a>

Define the contrastive loss function.

In [6]:
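class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, projections):
        # This simplified loss works on the encoder features; `projections` is accepted but not used
        bs = features.size(0)
        features = F.normalize(features, dim=1)
        # Pairwise cosine similarities between all samples in the batch, scaled by the temperature
        similarity_matrix = torch.matmul(features, features.T) / self.temperature
        # Each sample's target is its own index (the diagonal of the similarity matrix).
        # Targets are created on the same device as the inputs, so the loss also runs on CPU.
        targets = torch.arange(bs, device=features.device)
        loss = F.cross_entropy(similarity_matrix, targets)
        return loss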
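For comparison, the loss commonly associated with SimCLR (NT-Xent) is computed on the projections of two augmented views and excludes each sample's similarity with itself. The following is a sketch under those assumptions, not the loss used in this tutorial's training run. The function name `nt_xent_loss` and the random inputs are illustrative:

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z_a, z_b, temperature=0.5):
    """NT-Xent loss for two batches of projections, each of shape (N, D).

    Row i of z_a and row i of z_b are two views of the same image (positives);
    every other row in the combined batch serves as a negative.
    """
    n = z_a.size(0)
    # Stack both views and scale every row to unit length: shape (2N, D)
    z = F.normalize(torch.cat([z_a, z_b], dim=0), dim=1)
    # Scaled cosine similarities between all pairs: shape (2N, 2N)
    sim = z @ z.T / temperature
    # A sample is never a candidate match for itself
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row i is row i + N, and vice versa
    targets = torch.cat([
        torch.arange(n, 2 * n, device=z.device),
        torch.arange(0, n, device=z.device),
    ])
    return F.cross_entropy(sim, targets)

torch.manual_seed(0)
z_a, z_b = torch.randn(8, 128), torch.randn(8, 128)
loss_unrelated = nt_xent_loss(z_a, z_b)          # views with no relation to each other
loss_identical = nt_xent_loss(z_a, z_a.clone())  # identical views, expected to score lower
print(loss_unrelated.item(), loss_identical.item())
```

Using this loss would also require a data pipeline that yields two views per image and a training loop that passes both sets of projections.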

## Init the Model <a class="anchor" id="init_model"></a>

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

encoder = Encoder().to(device)
projection_head = ProjectionHead(256 * 4 * 4, 256, 128).to(device)  # input: 256 channels x 4 x 4 map (32x32 image after three 2x2 poolings)
model = SimCLR(encoder, projection_head).to(device)
criterion = ContrastiveLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

## Training Loop <a class="anchor" id="training_loop"></a>

Define the training loop.

In [8]:
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0
    for batch in train_loader:
        images, _ = batch
        images = images.to(device)
        features, projections = model(images)
        loss = criterion(features, projections)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

Epoch [1/10], Loss: 3.9768
Epoch [2/10], Loss: 4.1543
Epoch [3/10], Loss: 4.1545
Epoch [4/10], Loss: 4.1531
Epoch [5/10], Loss: 4.1527
Epoch [6/10], Loss: 4.1546
Epoch [7/10], Loss: 4.1554
Epoch [8/10], Loss: 4.1553
Epoch [9/10], Loss: 4.1548
Epoch [10/10], Loss: 4.1566
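One way to read these numbers: cross-entropy over a row of `B` equal logits evaluates to `ln(B)`. With the batch size of 64 used above, that is about 4.159, close to the values printed from epoch 2 onward. This would be consistent with a similarity matrix that carries little signal, although the log alone does not establish the cause. A quick check of the arithmetic:

```python
import math

batch_size = 64  # matches the DataLoader batch size used in this tutorial

# Cross-entropy for a uniform prediction over `batch_size` candidates
uniform_loss = -math.log(1.0 / batch_size)
print(f"{uniform_loss:.4f}")  # 4.1589
```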


# Downstream Task (Image Classification) <a class="anchor" id="img_classification"></a>

In this section, we train a simple linear classifier on top of the frozen encoder of the SimCLR model.
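"Frozen" here means that the encoder's weights are not updated while the classifier trains. A minimal sketch of that setup, using tiny stand-in modules with hypothetical sizes rather than the tutorial's actual encoder:

```python
import torch
import torch.nn as nn

# Tiny stand-in modules (hypothetical sizes) so the snippet runs on its own
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU())
classifier = nn.Linear(64, 10)

# Freeze the encoder: its weights receive no gradients and stay fixed
for param in encoder.parameters():
    param.requires_grad_(False)
encoder.eval()

# Only the classifier's parameters are handed to the optimizer
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.0005)

images = torch.rand(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))
loss = nn.functional.cross_entropy(classifier(encoder(images)), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(all(p.grad is None for p in encoder.parameters()))  # True: encoder untouched
```

Wrapping the encoder call in `torch.no_grad()` is an alternative that has the same effect on the encoder and also skips building its computation graph.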

## Setup <a class="anchor" id="setup"></a>

In [9]:
# Reuse the SimCLR model trained above and set it to evaluation mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
model.to(device)

# Evaluation transform: match the preprocessing used during pretraining (ToTensor only).
# Random flips, affine shifts, and normalization are left out here because the encoder
# was trained without them and test-time features should be deterministic.
transform = transforms.ToTensor()

# Load the CIFAR-10 test dataset
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified


### Feature Extraction <a class="anchor" id="feature_extraction"></a>

In [10]:
# Extract features from a dataset using the encoder of the SimCLR model
def extract_features(data_loader, model):
  features = []
  labels = []
  with torch.no_grad():  # no gradients are needed for feature extraction
    for images, targets in data_loader:
      features_batch, _ = model(images.to(device))
      features.append(features_batch.cpu())  # keep features on the CPU, alongside the labels
      labels.append(targets)
  return torch.cat(features, dim=0), torch.cat(labels, dim=0)

# Extract features from the test dataset
test_features, test_labels = extract_features(test_loader, model)

### Init the Classifier <a class="anchor" id="init_classifier"></a>

In [11]:
# Define a simple linear classifier
class LinearClassifier(nn.Module):
  def __init__(self, input_dim, num_classes):
    super(LinearClassifier, self).__init__()
    self.fc = nn.Linear(input_dim, num_classes)

  def forward(self, x):
    return self.fc(x)

# Initialize the linear classifier on the flattened encoder output (256 * 4 * 4 = 4096 features)
classifier = LinearClassifier(input_dim=256 * 4 * 4, num_classes=10).to(device)

# Hyperparameter Tuning (Experiment with different learning rates and epochs)
learning_rate = 0.0005  # Adjust based on experimentation
num_epochs = 20  # Adjust based on experimentation

# Define optimizer and loss function
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

### Train the Classifier <a class="anchor" id="train_classifier"></a>

In [None]:
# Train the linear classifier on features from the frozen SimCLR encoder
for epoch in range(num_epochs):
  classifier.train()
  for images, labels in train_loader:
    images = images.to(device)
    labels = labels.to(device)
    # The encoder is frozen: compute its features without tracking gradients
    with torch.no_grad():
      features, _ = model(images)
    outputs = classifier(features)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  # Report the loss of the last batch in the epoch
  print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

### Evaluate the Classifier <a class="anchor" id="eval_classifier"></a>

In [None]:
# Evaluate on the test-set features extracted in the Feature Extraction step
classifier.eval()
correct = 0
total = 0
with torch.no_grad():
    outputs = classifier(test_features.to(device))
    labels = test_labels.to(device)
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print(f"Accuracy on the test set: {(100 * correct / total):.2f}%")
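Overall accuracy can hide classes that the classifier handles poorly. As an optional follow-up, here is a sketch of a per-class breakdown. The helper `per_class_accuracy` and the toy tensors are assumptions for illustration; in the tutorial they would be replaced by the real predictions and labels:

```python
import torch

def per_class_accuracy(predicted, labels, num_classes):
    """Fraction of correct predictions for each class index."""
    correct = torch.bincount(labels[predicted == labels], minlength=num_classes)
    total = torch.bincount(labels, minlength=num_classes)
    # clamp avoids division by zero for classes that never appear in `labels`
    return correct.float() / total.clamp(min=1).float()

# Toy predictions and labels for 3 classes (made-up values)
toy_predicted = torch.tensor([0, 1, 1, 2, 2, 0])
toy_labels = torch.tensor([0, 1, 2, 2, 2, 1])
class_acc = per_class_accuracy(toy_predicted, toy_labels, num_classes=3)
print(class_acc)
```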

## Notes

This tutorial provides a basic implementation of contrastive learning with SimCLR. You can further experiment by adjusting hyperparameters, using different datasets, or exploring advanced techniques like data augmentations and different architectures.

**[Go to Top](#top)**