In [1]:
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import torch
import torchvision
from torch import nn

from lightly.data import LightlyDataset
from lightly.data.multi_view_collate import MultiViewCollate
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead
from lightly.transforms import SimSiamTransform

In [2]:
class SimSiam(nn.Module):
    """SimSiam model: a backbone followed by a projection and a prediction head.

    Args:
        backbone: Module that maps an input batch to features. Its output is
            flattened from dimension 1 onwards and must then provide
            ``feature_dim`` values per sample.
        feature_dim: Size of the flattened backbone output, passed as the
            input dimension of the projection head. Defaults to 512.
        proj_hidden_dim: Second argument (hidden dimension) of the projection
            head. Defaults to 512.
        out_dim: Size of the projection ``z`` and of the prediction ``p``.
            Defaults to 128.
        pred_hidden_dim: Second argument (hidden dimension) of the prediction
            head. Defaults to 64.

    With the default values the heads are built exactly as
    ``SimSiamProjectionHead(512, 512, 128)`` and
    ``SimSiamPredictionHead(128, 64, 128)``.
    """

    def __init__(
        self,
        backbone,
        feature_dim=512,
        proj_hidden_dim=512,
        out_dim=128,
        pred_hidden_dim=64,
    ):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimSiamProjectionHead(
            feature_dim, proj_hidden_dim, out_dim
        )
        # The prediction head maps a projection back into the projection space,
        # hence the same size on its input and output side.
        self.prediction_head = SimSiamPredictionHead(
            out_dim, pred_hidden_dim, out_dim
        )

    def forward(self, x):
        """Compute projection and prediction for the batch ``x``.

        Returns:
            A tuple ``(z, p)``. ``z`` is the projection detached from the
            autograd graph, so no gradient flows back through it. ``p`` is the
            prediction, computed from the projection before it was detached.
        """
        features = self.backbone(x).flatten(start_dim=1)
        projection = self.projection_head(features)
        prediction = self.prediction_head(projection)
        # Detach only the returned projection; `prediction` above keeps its
        # gradient path through the projection head and the backbone.
        return projection.detach(), prediction


In [3]:
# Backbone: a ResNet-18 with its last child module removed (for torchvision's
# ResNet that is the final fully connected layer), wrapped as one Sequential.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SimSiam(backbone)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Last expression of the cell: moves the model to `device` and, because
# `Module.to` returns the module, displays the model structure as cell output.
model.to(device)

SimSiam(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

In [4]:
# Transform applied to every image; the input size matches the 32x32 CIFAR-10 images.
transform = SimSiamTransform(input_size=32)

# CIFAR-10 is stored under datasets/cifar10 and downloaded if requested.
cifar10 = torchvision.datasets.CIFAR10(root="datasets/cifar10", download=True)

# Wrap the torchvision dataset so that `transform` is applied by lightly.
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transform)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

Files already downloaded and verified


In [5]:
# Show the transform's class followed by its default repr, one per line.
print(type(transform), transform, sep="\n")

<class 'lightly.transforms.simsiam_transform.SimSiamTransform'>
<lightly.transforms.simsiam_transform.SimSiamTransform object at 0x7f5c5e649820>


In [6]:
# Show the dataset's class and the shape of its raw image array, one per line.
print(type(cifar10), cifar10.data.shape, sep="\n")

<class 'torchvision.datasets.cifar.CIFAR10'>
(50000, 32, 32, 3)


In [7]:
# Collate function handed to the DataLoader; the training loop unpacks every
# batch it produces as `(x0, x1), _, _`.
collate_fn = MultiViewCollate()

In [8]:
# Batches of 256 samples, reshuffled every epoch and loaded by 8 worker
# processes. `drop_last=True` discards a final batch smaller than 256.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
    collate_fn=collate_fn,
)

In [9]:
# Plain SGD over all model parameters (backbone and both heads).
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.06)
# Loss used by the training loop to compare projections with predictions.
criterion = NegativeCosineSimilarity()

In [10]:
# Train for 10 epochs and print the mean loss of every epoch.
print("Starting Training")
# Make the training mode explicit so that re-running this cell always trains
# in train mode, even if the model was switched to eval mode in the meantime.
model.train()
for epoch in range(10):
    total_loss = 0
    # Each batch is unpacked as a pair of inputs (presumably two augmented
    # views of the same images) plus two entries that are not used here.
    for (x0, x1), _, _ in dataloader:
        x0 = x0.to(device)
        x1 = x1.to(device)
        # Clear gradients before the backward pass rather than after the step:
        # if a previous run of this cell was interrupted between `backward()`
        # and `zero_grad()`, stale gradients would otherwise be added to the
        # first update of this run.
        optimizer.zero_grad()
        z0, p0 = model(x0)
        z1, p1 = model(x1)
        # Symmetrized loss: each (detached) projection is compared with the
        # prediction computed from the other input.
        loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
        # Accumulate a detached copy so the autograd graph is not kept alive.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")


Starting Training
epoch: 00, loss: -0.34282
epoch: 01, loss: -0.67073
epoch: 02, loss: -0.70016
epoch: 03, loss: -0.72145
epoch: 04, loss: -0.73535
epoch: 05, loss: -0.75015
epoch: 06, loss: -0.75681
epoch: 07, loss: -0.76152
epoch: 08, loss: -0.76474
epoch: 09, loss: -0.76500
