# PyTorch Lightning and Lightly for Contrastive Self-Supervised Learning with MoCo

PyTorch Lightning is a high-level framework built on top of PyTorch. Lightning is designed to abstract away the boilerplate code required for deep learning projects, allowing Lightning users to focus more on the research and less on the underlying engineering complexity.

Self-supervised learning (SSL) has emerged as a powerful strategy for scaling up the amount of data that can be used to train machine learning models. In SSL, the training signal is derived from unlabeled data by exploiting properties of the data itself, for example by treating two augmented views of the same image as a matching pair. This allows models to learn rich representations without manually annotating datasets.

Contrastive learning is an SSL approach that teaches a model to distinguish between similar (positive) and dissimilar (negative) pairs of data points. The model learns embeddings in which positive pairs, such as two augmentations of the same image, are pulled closer together while negative pairs are pushed apart. This is achieved with a contrastive loss function such as InfoNCE (based on Noise Contrastive Estimation, NCE) or the triplet loss.
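To make this concrete, here is a minimal sketch of a simplified InfoNCE loss in PyTorch. It assumes that row `i` of the keys is the positive for row `i` of the queries and that every other row in the batch acts as a negative. It is an illustration only, not Lightly's `NTXentLoss` (used later in this notebook), which additionally supports a memory bank of extra negatives; the name `info_nce_loss` is our own.

```python
import torch
import torch.nn.functional as F


def info_nce_loss(query: torch.Tensor, key: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Simplified InfoNCE: row i of `key` is the positive for row i of `query`,
    and every other row of `key` in the batch serves as a negative."""
    query = F.normalize(query, dim=1)
    key = F.normalize(key, dim=1)
    # (N, N) matrix of cosine similarities, scaled by the temperature
    logits = query @ key.T / temperature
    # the positive for sample i sits on the diagonal, i.e. "class" i
    targets = torch.arange(query.shape[0], device=query.device)
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
z = torch.randn(8, 128)
aligned = info_nce_loss(z, z + 0.01 * torch.randn(8, 128))  # near-identical "views"
unrelated = info_nce_loss(z, torch.randn(8, 128))  # keys unrelated to the queries
print(f"aligned: {aligned.item():.4f}, unrelated: {unrelated.item():.4f}")
```

Near-identical views produce a loss close to zero, while unrelated keys produce a noticeably larger loss, which is exactly the pressure that pulls positives together and pushes negatives apart.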

Momentum Contrast (MoCo) is a specific instantiation of contrastive learning designed to provide a large and consistent set of negative examples. MoCo maintains a dynamic dictionary, implemented as a queue of encoded samples (keys), and computes those keys with a momentum encoder: a copy of the main encoder whose weights follow the main encoder as an exponential moving average. Because the queue decouples the number of negatives from the batch size, and the slowly changing momentum encoder keeps keys from different batches comparable, MoCo can contrast each query against far more negatives than a single batch holds. This improves the quality of the learned representations and makes contrastive learning more scalable.
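The two ingredients of MoCo can be sketched in a few lines of PyTorch. This is a minimal illustration under our own assumptions, not Lightly's implementation: `momentum_update` applies the exponential moving average `theta_k <- m * theta_k + (1 - m) * theta_q` (in this notebook, the calls to Lightly's `update_momentum` with `m = 0.99` play this role), and `KeyQueue` is a hypothetical fixed-size FIFO of keys standing in for the dictionary of negatives.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def momentum_update(encoder: nn.Module, momentum_encoder: nn.Module, m: float = 0.99) -> None:
    # exponential moving average: theta_k <- m * theta_k + (1 - m) * theta_q
    for p_q, p_k in zip(encoder.parameters(), momentum_encoder.parameters()):
        p_k.mul_(m).add_(p_q, alpha=1 - m)


class KeyQueue:
    """Hypothetical fixed-size FIFO queue of encoded keys used as negatives."""

    def __init__(self, size: int, dim: int):
        self.keys = F.normalize(torch.randn(size, dim), dim=1)
        self.ptr = 0

    @torch.no_grad()
    def enqueue(self, new_keys: torch.Tensor) -> None:
        size = self.keys.shape[0]
        idx = (self.ptr + torch.arange(new_keys.shape[0])) % size
        self.keys[idx] = new_keys  # overwrite the oldest entries
        self.ptr = (self.ptr + new_keys.shape[0]) % size


# usage: the momentum encoder trails the online encoder
torch.manual_seed(0)
encoder = nn.Linear(4, 2)
momentum_encoder = copy.deepcopy(encoder)
w0 = encoder.weight.detach().clone()
with torch.no_grad():
    encoder.weight.add_(1.0)  # pretend an optimizer step moved the online weights
momentum_update(encoder, momentum_encoder, m=0.9)  # momentum weights move by only 0.1

queue = KeyQueue(size=6, dim=2)
queue.enqueue(torch.ones(4, 2))  # fills slots 0-3
queue.enqueue(torch.zeros(4, 2))  # fills slots 4-5, then wraps around to slots 0-1
```

Because `m` is close to 1, keys computed several batches apart come from nearly the same encoder, which is what keeps the queued negatives consistent.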

### Imports

In [None]:
# Data loading

import torch
import torchvision

from lightly.transforms import MoCoV2Transform, utils
from lightly.data import LightlyDataset

# MoCo SSL model definition

import copy

import torch.nn as nn
import pytorch_lightning as pl

from lightly.models import ResNetGenerator
from lightly.models.modules.heads import MoCoProjectionHead
from lightly.loss import NTXentLoss
from lightly.models.utils import (
    batch_shuffle,
    batch_unshuffle,
    deactivate_requires_grad,
    update_momentum,
)

# Logging and results
# Weight histograms are logged through Lightning's TensorBoardLogger, which needs
# the `tensorboard` package; TensorFlow itself is not required.

from pytorch_lightning.loggers import CSVLogger  # capture Lightning training metrics as CSV

import pandas as pd

### Initialize pipeline parameters

In [None]:
# Data loading parameters
num_workers = 8
batch_size = 512

In [None]:
# MoCo parameters
memory_bank_size = 4096
seed = 1
max_epochs = 100  # ~5 epochs is enough to check the execution flow; use 100 for meaningful results

### Data Loading

The following dataloaders load and preprocess the data.

##### Paths to PNG images

In [None]:
path_to_train = "../data_cifar10/train/"
path_to_test = "../data_cifar10/test/"
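These paths assume CIFAR-10 has already been written to disk as PNG files with one subfolder per class, which is the layout from which we understand `LightlyDataset` derives class labels. If you still need to create those folders, the following sketch is one way to do it. `export_to_class_folders` is a hypothetical helper, and the `../cifar10_raw` download directory in the commented example is an assumption.

```python
from pathlib import Path
from typing import Iterable, Sequence, Tuple

from PIL import Image


def export_to_class_folders(
    samples: Iterable[Tuple[Image.Image, int]],
    class_names: Sequence[str],
    out_dir: str,
) -> None:
    """Save (PIL image, label index) pairs as out_dir/<class_name>/<index>.png."""
    root = Path(out_dir)
    for i, (img, label) in enumerate(samples):
        class_dir = root / class_names[label]
        class_dir.mkdir(parents=True, exist_ok=True)
        img.save(class_dir / f"{i:05d}.png")


# Example (downloads CIFAR-10; torchvision's CIFAR10 yields PIL images and int labels):
# import torchvision
# train = torchvision.datasets.CIFAR10(root="../cifar10_raw", train=True, download=True)
# export_to_class_folders(train, train.classes, "../data_cifar10/train/")
```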

##### Dataloader for MoCo Training

In [None]:
# disable blur because we're working with tiny images
transform = MoCoV2Transform(
    input_size=32,
    gaussian_blur=0.0,
)

In [None]:
# Use the moco augmentations for training moco
dataset_train_moco = LightlyDataset(input_dir=path_to_train, 
                                    transform=transform)

In [None]:
dataloader_train_moco = torch.utils.data.DataLoader(
    dataset_train_moco,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)

##### Dataloader for Classifier Training

In [None]:
# Augmentations typically used to train on cifar-10
train_classifier_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=utils.IMAGENET_NORMALIZE["mean"],
            std=utils.IMAGENET_NORMALIZE["std"],
        ),
    ]
)


##### Note on augmentations

We train a linear classifier on top of the pretrained MoCo backbone, which stays frozen. The MoCo augmentations are strong, and they often reduce accuracy when used to train models that are not learning through a contrastive objective. Since the linear layer is trained with a cross-entropy loss against the dataset's labels, we use milder augmentations instead: the standard CIFAR-10 random crop and horizontal flip defined above.

In [None]:
dataset_train_classifier = LightlyDataset(
    input_dir=path_to_train, 
    transform=train_classifier_transforms
)

In [None]:
dataloader_train_classifier = torch.utils.data.DataLoader(
    dataset_train_classifier,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)

##### Dataloader for test data

In [None]:
# No additional augmentations for the test set
test_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((32, 32)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=utils.IMAGENET_NORMALIZE["mean"],
            std=utils.IMAGENET_NORMALIZE["std"],
        ),
    ]
)

In [None]:
dataset_test = LightlyDataset(input_dir=path_to_test, 
                              transform=test_transforms)

In [None]:
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

### MoCo Model Definition with PyTorch-Lightning

In [None]:
class MocoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()

        # create a ResNet backbone and remove the classification head
        resnet = ResNetGenerator("resnet-18", 1, num_splits=8)
        self.backbone = nn.Sequential(
            *list(resnet.children())[:-1],
            nn.AdaptiveAvgPool2d(1),
        )

        # create a moco model based on ResNet
        self.projection_head = MoCoProjectionHead(512, 512, 128)
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)
        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

        # create our loss with the optional memory bank
        self.criterion = NTXentLoss(
            temperature=0.1, memory_bank_size=(memory_bank_size, 128)
        )

    def training_step(self, batch, batch_idx):
        (x_q, x_k), _, _ = batch

        # update momentum
        update_momentum(self.backbone, self.backbone_momentum, 0.99)
        update_momentum(self.projection_head, self.projection_head_momentum, 0.99)

        # get queries
        q = self.backbone(x_q).flatten(start_dim=1)
        q = self.projection_head(q)

        # get keys
        k, shuffle = batch_shuffle(x_k)
        k = self.backbone_momentum(k).flatten(start_dim=1)
        k = self.projection_head_momentum(k)
        k = batch_unshuffle(k, shuffle)

        loss = self.criterion(q, k)
        self.log("train_loss_ssl", loss)
        return loss

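    def on_train_epoch_end(self):
        self.custom_histogram_weights()

    # We provide a helper method to log weight histograms in TensorBoard,
    # which is useful for debugging.
    def custom_histogram_weights(self):
        # add_histogram only exists on TensorBoard's SummaryWriter; skip it for
        # other loggers (e.g. CSVLogger) or when no logger is attached
        experiment = getattr(self.logger, "experiment", None)
        if not hasattr(experiment, "add_histogram"):
            return
        for name, params in self.named_parameters():
            experiment.add_histogram(name, params, self.current_epoch)

    def configure_optimizers(self):
        optim = torch.optim.SGD(
            self.parameters(),
            lr=6e-2,
            momentum=0.9,
            weight_decay=5e-4,
        )
        # anneal over the number of epochs this trainer will actually run
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim, self.trainer.max_epochs
        )
        return [optim], [scheduler]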
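`CosineAnnealingLR` decays the learning rate along half a cosine wave over `T_max` scheduler steps (one step per epoch with Lightning's default scheduler interval). The sketch below writes out the closed form, `lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2`, and checks it against PyTorch's scheduler; `cosine_lr` is our own helper name.

```python
import math

import torch


def cosine_lr(epoch: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    # closed form of cosine annealing over t_max steps
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# compare with PyTorch's CosineAnnealingLR, stepped once per "epoch"
param = torch.nn.Parameter(torch.zeros(1))
optim = torch.optim.SGD([param], lr=6e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=10)
observed = []
for epoch in range(10):
    observed.append(optim.param_groups[0]["lr"])
    optim.step()
    scheduler.step()
expected = [cosine_lr(e, 6e-2, 10) for e in range(10)]
```

One practical consequence is that `T_max` should match the number of epochs actually trained. With `T_max=100` but only 10 epochs, the rate barely decays: `cosine_lr(9, 30.0, 100)` is about 29.4 for a base rate of 30.0.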

In [None]:
# Instantiate the MoCo model
model = MocoModel()

In [None]:
# Train MoCo model using the lightning trainer

# accelerator="auto" uses a GPU when one is available and falls back to the CPU otherwise
trainer = pl.Trainer(max_epochs=max_epochs, devices=1, accelerator="auto")

In [None]:
trainer.fit(model, dataloader_train_moco)
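In `training_step`, the keys are computed on a shuffled copy of the batch and then unshuffled so that key `i` lines up with query `i` again. Our understanding is that this mirrors MoCo's "shuffling BN": because `ResNetGenerator(..., num_splits=8)` computes BatchNorm statistics over 8 sub-batches, shuffling means a sample's key is normalized together with a different set of samples than its query, so the model cannot exploit shared batch statistics to find the positive. The sketch below shows only the permute and inverse-permute logic, using hypothetical `shuffle_batch` and `unshuffle_batch` helpers; it is not Lightly's `batch_shuffle`/`batch_unshuffle` code.

```python
import torch


def shuffle_batch(x: torch.Tensor):
    """Randomly permute the batch; return the permuted tensor and the permutation."""
    perm = torch.randperm(x.shape[0], device=x.device)
    return x[perm], perm


def unshuffle_batch(x: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Undo shuffle_batch so that row i lines up with original sample i again."""
    inverse = torch.argsort(perm)
    return x[inverse]


torch.manual_seed(0)
x = torch.arange(8, dtype=torch.float32).unsqueeze(1)  # 8 samples, 1 feature each
shuffled, perm = shuffle_batch(x)
# any per-sample computation (here: * 10) commutes with the permutation
restored = unshuffle_batch(shuffled * 10, perm)
```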

### Build a Linear Classifier Using Extracted MoCo Features

In [None]:
class Classifier(pl.LightningModule):
    def __init__(self, backbone):
        super().__init__()
        # use the pretrained ResNet backbone
        self.backbone = backbone

        # freeze the backbone
        deactivate_requires_grad(backbone)

        # create a linear layer for our downstream classification model
        self.fc = nn.Linear(512, 10)

        self.criterion = nn.CrossEntropyLoss()
        self.validation_step_outputs = []

    def forward(self, x):
        y_hat = self.backbone(x).flatten(start_dim=1)
        y_hat = self.fc(y_hat)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        self.log("train_loss_fc", loss)
        return loss

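    def on_train_epoch_start(self):
        # Lightning switches all submodules to train mode during fit; keep the
        # frozen backbone in eval mode so its BatchNorm statistics stay fixed
        self.backbone.eval()

    def on_train_epoch_end(self):
        self.custom_histogram_weights()

    # We provide a helper method to log weight histograms in TensorBoard,
    # which is useful for debugging.
    def custom_histogram_weights(self):
        # add_histogram only exists on TensorBoard's SummaryWriter; the CSVLogger
        # used for this classifier does not provide it, so skip logging in that case
        experiment = getattr(self.logger, "experiment", None)
        if not hasattr(experiment, "add_histogram"):
            return
        for name, params in self.named_parameters():
            experiment.add_histogram(name, params, self.current_epoch)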

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.forward(x)
        y_hat = torch.nn.functional.softmax(y_hat, dim=1)

        # calculate number of correct predictions
        _, predicted = torch.max(y_hat, 1)
        num = predicted.shape[0]
        correct = (predicted == y).float().sum()
        self.validation_step_outputs.append((num, correct))
        return num, correct

    def on_validation_epoch_end(self):
        # calculate and log top1 accuracy
        if self.validation_step_outputs:
            total_num = 0
            total_correct = 0
            for num, correct in self.validation_step_outputs:
                total_num += num
                total_correct += correct
            acc = total_correct / total_num
            self.log("val_acc", acc, on_epoch=True, prog_bar=True)
            self.validation_step_outputs.clear()

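    def configure_optimizers(self):
        # only the linear head is optimized; the backbone stays frozen
        optim = torch.optim.SGD(self.fc.parameters(), lr=30.0)
        # anneal over the number of epochs this trainer will actually run
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim, self.trainer.max_epochs
        )
        return [optim], [scheduler]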
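`dataloader_test` uses `drop_last=False`, so the final validation batch can be smaller than the rest. Summing correct predictions and sample counts across batches, as `on_validation_epoch_end` does, weights every test image equally, whereas averaging per-batch accuracies would over-weight the short batch. The standalone sketch below uses a hypothetical `pooled_top1` helper and toy logits. It takes the argmax of the raw logits, which gives the same prediction as after a softmax because softmax preserves ordering.

```python
import torch


def pooled_top1(batches):
    """Top-1 accuracy over all samples: sum correct and total counts across batches."""
    total_correct, total_num = 0, 0
    for logits, targets in batches:
        predicted = logits.argmax(dim=1)  # same result as argmax after softmax
        total_correct += (predicted == targets).sum().item()
        total_num += targets.shape[0]
    return total_correct / total_num


batches = [
    (torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.0]]), torch.tensor([0, 1, 1])),  # 2 of 3 correct
    (torch.tensor([[0.0, 3.0]]), torch.tensor([1])),  # 1 of 1 correct (short last batch)
]
acc = pooled_top1(batches)
```

Here the pooled accuracy is 3/4 = 0.75, while the mean of the per-batch accuracies would be (2/3 + 1) / 2 ≈ 0.83.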

In [None]:
model.eval()

In [None]:
classifier = Classifier(model.backbone)
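Before training, it can be worth confirming that `deactivate_requires_grad` really froze the backbone, so that only the linear head will be updated. The sketch below uses a hypothetical `count_parameters` helper on a toy model. For the notebook's `classifier`, we would expect the trainable count to equal the 512 × 10 + 10 = 5,130 parameters of `classifier.fc`.

```python
import torch.nn as nn


def count_parameters(module: nn.Module):
    """Return (trainable, frozen) parameter counts."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return trainable, frozen


# toy stand-in for Classifier: a frozen "backbone" followed by a trainable linear head
backbone = nn.Linear(16, 8)
for p in backbone.parameters():
    p.requires_grad = False
toy = nn.Sequential(backbone, nn.Linear(8, 10))
trainable, frozen = count_parameters(toy)  # head: 8*10 + 10 = 90; backbone: 16*8 + 8 = 136
```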

In [None]:
# Train the linear classifier on the frozen MoCo backbone.
# accelerator="auto" uses a GPU when one is available and falls back to the CPU otherwise;
# CSVLogger writes the metrics to logs2/lightning_logs/version_<n>/metrics.csv
trainer = pl.Trainer(
    max_epochs=10,
    devices=1,
    accelerator="auto",
    logger=CSVLogger(save_dir="logs2/"),
)

In [None]:
trainer.fit(
    classifier,
    train_dataloaders=dataloader_train_classifier,
    val_dataloaders=dataloader_test,
)

In [None]:
# directory where CSVLogger wrote metrics.csv
trainer.logger.log_dir

In [None]:
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")

# see the PyTorch Lightning documentation for how to extract a CSV of training output
# https://lightning.ai/docs/pytorch/stable/common/trainer.html
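`CSVLogger` typically writes one row per logging event, so `train_loss_fc` (logged during training steps) and `val_acc` (logged at the end of each validation epoch) end up on different rows, with empty cells elsewhere. Here is a minimal sketch for collapsing the file into one row per epoch, assuming the `epoch`, `step`, `train_loss_fc`, and `val_acc` columns produced by this notebook; `per_epoch_summary` is our own helper.

```python
import pandas as pd


def per_epoch_summary(metrics: pd.DataFrame) -> pd.DataFrame:
    """Collapse CSVLogger rows into one row per epoch.

    groupby().mean() skips NaN cells, so each metric is averaged over the
    rows where it was actually logged (e.g. mean training loss per epoch).
    """
    summary = metrics.groupby("epoch").mean(numeric_only=True)
    return summary.drop(columns="step", errors="ignore")


# usage in the notebook, after the cell above:
# summary = per_epoch_summary(metrics)
# summary[["train_loss_fc", "val_acc"]]
```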

### Conclusion

Lightly shows a lot of potential for simplifying SSL workflows, and integrating Lightly with PyTorch Lightning is a great idea. The Lightly team should add documentation on extracting SSL training results and on measuring their impact on downstream tasks such as image classification.