# Self supervised learning
* Author: Jing Zhang
* Date: 2024/10/07
* reference:https://docs.lightly.ai/self-supervised-learning/index.html
![overview](../figures/lightly_overview.png)    
* installation:    
`pip install lightly`

Data load

In [4]:
import torch
import lightly.data as data

from lightly.transforms.simclr_transform import SimCLRTransform # augmentation methods for specific models, e.g. SimCLR

# The following transform will return two augmented images per input image.
transform = SimCLRTransform()

# Create a dataset from your image folder.
dataset = data.LightlyDataset(
    input_dir='path/unlabeled_data/',
    transform=transform,
)

# Build a PyTorch dataloader.
dataloader = torch.utils.data.DataLoader(
    dataset,                # Pass the dataset to the dataloader.
    batch_size=128,         # A large batch size helps with learning.
    shuffle=True,           # Shuffling is important!
)

Deploy model

In [8]:
import torchvision
from lightly.loss import NTXentLoss # many losses are supported
from lightly.models.modules.heads import SimCLRProjectionHead # Use MLP to map input features to a lower dimensional space
import timm

# torchvision to load resnet backbone
#resnet = torchvision.models.resnet18()
#resnet = torch.nn.Sequential(*list(resnet.children())[:-1]) # remove last classifier layer

# timm lib to load resnet backbone
resnet = timm.create_model('resnet18', pretrained=True)
resnet.reset_classifier(0) # remove last classifier layer

# build a SimCLR model
class SimCLR(torch.nn.Module):
    def __init__(self, backbone, hidden_dim, out_dim):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, out_dim)

    def forward(self, x):
        h = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(h)
        return z

model = SimCLR(resnet, hidden_dim=512, out_dim=128)
criterion = NTXentLoss(temperature=0.5) # (normalized temperature-scaled cross entropy loss)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-0, weight_decay=1e-5)

Launch training

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
max_epochs = 10
for epoch in range(max_epochs):
    for (x0, x1), _, _ in dataloader:

        x0 = x0.to(device)
        x1 = x1.to(device)

        z0 = model(x0)
        z1 = model(x1)

        loss = criterion(z0, z1)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
    print(f'Epoch {epoch+1}/{max_epochs}:{loss}')

Epoch 1/10:5.30142879486084
Epoch 2/10:5.437026500701904
Epoch 3/10:5.502232074737549
Epoch 4/10:5.425705432891846
Epoch 5/10:5.390667915344238
Epoch 6/10:5.379733562469482
Epoch 7/10:5.440232276916504
Epoch 8/10:5.428023338317871
Epoch 9/10:5.513458728790283
Epoch 10/10:5.411945819854736


Embedding save

In [14]:
# make a new dataloader without the transformations
# The only transformation needed is to make a torch tensor out of the PIL image
dataset.transform = torchvision.transforms.ToTensor()
dataloader = torch.utils.data.DataLoader(
    dataset,        # use the same dataset as before
    batch_size=1,   # we can use batch size 1 for inference
    shuffle=False,  # don't shuffle your data during inference
)

# embed your image dataset
# Use the trained model to extract features and characterize new data to support downstream tasks such as reasoning, retrieval, and clustering
embeddings = [] 
model.eval()
with torch.no_grad():
    for img, label, fnames in dataloader:
        img = img.to(model.device)
        emb = model.backbone(img).flatten(start_dim=1)
        embeddings.append(emb)

    embeddings = torch.cat(embeddings, 0)
print(embeddings.shape)

torch.Size([1000, 512])


Pytorch lightning framework

In [13]:
import pytorch_lightning as pl

# class SimCLR(torch.nn.Module):
#     def __init__(self, backbone, hidden_dim, out_dim):
#         super().__init__()
#         self.backbone = backbone
#         self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, out_dim)

#     def forward(self, x):
#         h = self.backbone(x).flatten(start_dim=1)
#         z = self.projection_head(h)
#         return z
        
class SimCLR(pl.LightningModule):
    def __init__(self, backbone, hidden_dim, out_dim):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, out_dim)
        self.criterion = NTXentLoss(temperature=0.5)

    def forward(self, x):
        h = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(h)
        return z

    def training_step(self, batch, batch_idx):
        (x0, x1), _, _ = batch
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-0)
        return optimizer

max_epochs = 10

model = SimCLR(resnet, hidden_dim=512, out_dim=128)
trainer = pl.Trainer(max_epochs=max_epochs, devices=1, accelerator="gpu")
trainer.fit(model, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                 | Params | Mode 
-----------------------------------------------------------------
0 | backbone        | ResNet               | 11.2 M | train
1 | projection_head | SimCLRProjectionHead | 328 K  | train
2 | criterion       | NTXentLoss           | 0      | train
-----------------------------------------------------------------
11.5 M    Trainable params
0         Non-trainable params
11.5 M    Total params
46.022    Total estimated model params size (MB)
103       Modules in train mode
0         Modules in eval mode
c:\Users\zhang\miniconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLo

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


## Concepts

transforms

In [None]:
from lightly.transforms import SimCLRTransform
# customize the transform
transform = SimCLRTransform(
    input_size=128,   # resize input images to 128x128 pixels
    cj_prob=0.0,      # disable color jittering
    rr_prob=0.5,      # apply random rotation by 90 degrees with 50% probability
)

In [None]:
from torchvision import transforms as T
from lightly.transforms.multi_view_transform import MultiViewTransform

# Create a global view transform that crops 224x224 patches from the input image.
global_view = T.Compose([
    T.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomGrayscale(p=0.5),
    T.ToTensor(),
])

# Create a local view transform that crops a random portion of the input image and resizes it to a 96x96 patch.
local_view = T.Compose([
    T.RandomResizedCrop(size=96, scale=(0.05, 0.4)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomGrayscale(p=0.5),
    T.ToTensor(),
])

# Combine the transforms. Every transform will create one view.
# The final transform will create four views: two global and two local views.
transform = MultiViewTransform([global_view, global_view, local_view, local_view])
views = transform(image) # return 4 transformed views