In [None]:
!pip install pytorch_lightning

In [25]:
from typing import Any
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR100, STL10
from torchvision import transforms
from torchvision.models import resnet18

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

In [26]:
class LoadUnlabelData(Dataset):
    """Wrap a dataset and overwrite the second field of every sample with ``-1``.

    Each sample fetched from the wrapped dataset is converted to a list and
    its element at position 1 is replaced by ``-1`` before being returned.
    All other fields are passed through unchanged.
    """

    def __init__(self, dataset):
        """Keep a reference to the dataset being wrapped.

        Args:
            dataset: Any object supporting ``len()`` and integer indexing
                whose items are sequences with at least two elements.
        """
        super().__init__()
        self.dataset = dataset

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index) -> Any:
        """Return sample ``index`` as a list with position 1 set to ``-1``."""
        sample = [*self.dataset[index]]
        sample[1] = -1  # position 1 is overwritten regardless of its original value
        return sample

class ContrastiveTransformations(object):
    """Callable that applies one transform to the same input several times.

    Args:
        base_transforms: Callable invoked on the input once per view.
        n_views: How many times ``base_transforms`` is applied (default 2).
    """

    def __init__(self, base_transforms, n_views=2):
        self.n_views = n_views
        self.base_transforms = base_transforms

    def __call__(self, x):
        """Return a list holding ``n_views`` results of ``base_transforms(x)``."""
        views = []
        for _ in range(self.n_views):
            # A separate call per view; a randomized transform therefore
            # yields a different result for each entry.
            views.append(self.base_transforms(x))
        return views

In [27]:
class SimCLR(pl.LightningModule):
    """Lightning module that trains a ResNet-18 encoder with an InfoNCE loss.

    Args:
        hidden_dim: Output size of the final linear layer of the network head.
            The layer before it has ``4 * hidden_dim`` units.
        learning_rate: Initial learning rate handed to AdamW; the cosine
            schedule anneals it down to ``learning_rate / 50``.
        temperature: Positive scalar the cosine similarities are divided by.
        weight_decay: Weight decay handed to AdamW.
        max_epochs: ``T_max`` of the cosine annealing schedule.
    """

    def __init__(self, hidden_dim, learning_rate, temperature, weight_decay, max_epochs=100):
        super().__init__()
        # Stores the constructor arguments under ``self.hparams``.
        self.save_hyperparameters()
        assert self.hparams.temperature > 0.0, f'The temperature = {temperature} smaller than 0!'

        head_width = 4 * hidden_dim
        backbone = resnet18(num_classes=head_width)
        # Keep the ResNet's own final linear layer and stack ReLU + Linear on
        # top of it, so the network output has ``hidden_dim`` features.
        backbone.fc = nn.Sequential(
            backbone.fc,
            nn.ReLU(inplace=True),
            nn.Linear(head_width, hidden_dim),
        )
        self.convnet = backbone

    def configure_optimizers(self):
        """Return AdamW together with a cosine-annealing learning-rate schedule."""
        hp = self.hparams
        optimizer = torch.optim.AdamW(
            params=self.convnet.parameters(),
            lr=hp.learning_rate,
            weight_decay=hp.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer,
            T_max=hp.max_epochs,
            eta_min=hp.learning_rate / 50,
        )
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute the InfoNCE loss for a training batch and log under ``train_*``."""
        return self.info_nce_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        """Compute the InfoNCE loss for a validation batch and log under ``val_*``."""
        return self.info_nce_loss(batch, mode='val')

    def info_nce_loss(self, batch, mode='train'):
        """Compute the InfoNCE loss and ranking metrics for one batch.

        Args:
            batch: Pair ``(views, _)``; ``views`` is a sequence of image tensors
                that is concatenated along dim 0. The second element is ignored.
            mode: Prefix used for the logged metric names.

        Returns:
            The mean negative log-likelihood over all rows of the stacked batch.
        """
        views, _ = batch
        stacked = torch.cat(views, dim=0)

        # Encode all views at once, then compare every row with every other row.
        embeddings = self.convnet(stacked)
        similarity = torch.nn.functional.cosine_similarity(
            embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=-1
        )
        n_rows = similarity.shape[0]

        # Exclude each sample's similarity with itself by pushing it to a very
        # large negative value.
        diagonal = torch.eye(n_rows, dtype=torch.bool, device=similarity.device)
        similarity.masked_fill_(diagonal, value=-9e15)

        # The positive for a row sits half the stacked batch away; this pairs
        # up the two views when exactly two views were concatenated.
        positives = diagonal.roll(shifts=n_rows // 2, dims=0)

        # InfoNCE: temperature-scaled log-softmax evaluated at the positive.
        logits = similarity / self.hparams.temperature
        positive_logits = logits[positives]
        loss = (torch.logsumexp(logits, dim=-1) - positive_logits).mean()

        # Rank of the positive among all candidates: put the positive in
        # column 0, mask it out of the remaining columns, sort descending and
        # locate where column 0 ended up.
        candidates = torch.cat(
            [positive_logits.unsqueeze(-1), logits.masked_fill(positives, -9e15)],
            dim=-1,
        )
        positive_rank = candidates.argsort(dim=-1, descending=True).argmin(dim=-1)

        # Log the loss and the ranking metrics.
        self.log(f'{mode}_loss', loss)
        self.log(f'{mode}_acc_top1', (positive_rank == 0).float().mean())
        self.log(f'{mode}_acc_top5', (positive_rank < 5).float().mean())
        self.log(f'{mode}_acc_mean_pos', 1 + positive_rank.float().mean())
        return loss

In [None]:
if __name__ == '__main__':
    # Seed Python, NumPy and PyTorch (including dataloader workers) so that the
    # random augmentations and weight initialisation are repeatable across runs.
    pl.seed_everything(42, workers=True)

    # %% Data preparation for self-supervised learning
    # NOTE(review): CIFAR-100 images are 32x32, so RandomResizedCrop(size=96)
    # upsamples every view — confirm this is intended.
    contrast_transforms = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop(size=96),
                                            transforms.RandomApply([transforms.ColorJitter(brightness=0.5, contrast=0.5,
                                                                                            saturation=0.5, hue=0.1)], p=0.8),
                                            transforms.RandomGrayscale(p=0.2), transforms.GaussianBlur(kernel_size=9),
                                            transforms.ToTensor(), transforms.Normalize(mean=0.5, std=0.5)])

    # Training set: CIFAR-100 train split, two augmented views per image, with
    # the class label overwritten by -1.
    unlabeled_data = CIFAR100(root='.', download=True, train=True,
                        transform=ContrastiveTransformations(base_transforms=contrast_transforms, n_views=2))
    unlabeled_data = LoadUnlabelData(unlabeled_data)
    # Same train split with its labels intact; kept under its original name in
    # case other cells use it. It is no longer used for validation here.
    train_data_contrast = CIFAR100(root='.', download=True, train=True,
                                transform=ContrastiveTransformations(base_transforms=contrast_transforms, n_views=2))
    # Validation set: the held-out test split, so the monitored `val_*` metrics
    # (and hence checkpoint selection) are not computed on training images.
    val_data_contrast = CIFAR100(root='.', download=True, train=False,
                                transform=ContrastiveTransformations(base_transforms=contrast_transforms, n_views=2))

    # Hyperparameters
    batch_size, max_epochs = 256, 10
    hidden_dim, learning_rate, temperature, weight_decay = 128, 5e-4, 0.07, 1e-4
    num_workers = 9

    # `max_epochs` is shared by the Trainer and the model's cosine LR schedule
    # so the schedule length always matches the actual training length.
    trainer = pl.Trainer(default_root_dir=Path('.', 'SimCLR'),
                        accelerator='gpu', devices=1, max_epochs=max_epochs,
                        callbacks=[ModelCheckpoint(save_weights_only=True, mode='max', monitor='val_acc_top5'),
                                    LearningRateMonitor('epoch')])
    trainer.logger._default_hp_metric = None

    train_dataloader = DataLoader(dataset=unlabeled_data, batch_size=batch_size, shuffle=True,
                                    drop_last=True, pin_memory=True, num_workers=num_workers,
                                    persistent_workers=True)
    val_dataloader = DataLoader(val_data_contrast, batch_size=batch_size, shuffle=False,
                            drop_last=False, pin_memory=True, num_workers=num_workers,
                            persistent_workers=True)
    simclr_model = SimCLR(hidden_dim=hidden_dim, learning_rate=learning_rate, temperature=temperature,
                            weight_decay=weight_decay, max_epochs=max_epochs)
    trainer.fit(model=simclr_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type   | Params
-----------------------------------
0 | convnet | ResNet | 11.5 M
-----------------------------------
11.5 M    Trainable params
0         Non-trainable params
11.5 M    Total params
46.019    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]