In [1]:
from datasets import load_dataset
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import copy
import lightly

from lightly.models.modules.heads import MoCoProjectionHead
from lightly.models.utils import deactivate_requires_grad
from lightly.models.utils import update_momentum
from lightly.models.utils import batch_shuffle
from lightly.models.utils import batch_unshuffle

D:\miniconda\envs\diffusers\lib\site-packages\numpy\.libs\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll
D:\miniconda\envs\diffusers\lib\site-packages\numpy\.libs\libopenblas.GK7GX5KEQ4F6UYO3P26ULGBQYHGQO7J4.gfortran-win_amd64.dll


In [7]:
num_workers = 8
batch_size = 512
memory_bank_size = 4096  # size of the NT-Xent memory bank used by MocoModel's loss
seed = 1
max_epochs = 100  # also the T_max of MocoModel's cosine learning-rate schedule

# Apply the seed right away so that everything random that follows (e.g. the
# freshly initialised ResNet weights in MocoModel) is reproducible on
# Restart & Run All. workers=True additionally derives per-worker seeds for
# multi-process DataLoaders. Trailing ';' suppresses the returned seed's repr.
pl.seed_everything(seed, workers=True);

In [2]:
# Load the animal image/caption dataset from the Hugging Face Hub (served from the
# local datasets cache after the first download). Per the output below it has a
# single "train" split with 'image', 'caption' and 'embeddings' columns.
dataset = load_dataset("Isamu136/big-animal-dataset-with-l14")

Using custom data configuration Isamu136--big-animal-dataset-with-l14-e6365e3df6462f2d
Found cached dataset parquet (D:/cache/huggingface/datasets/Isamu136___parquet/Isamu136--big-animal-dataset-with-l14-e6365e3df6462f2d/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
# Show the available splits, column names and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'caption', 'embeddings'],
        num_rows: 62149
    })
})

In [4]:
# `rename_column` does NOT modify `dataset` in place: it returns a new
# DatasetDict. The result must be bound, otherwise the rename is silently lost.
# The guard keeps this cell idempotent: a second run would otherwise raise
# because the "embeddings" column no longer exists.
if "embeddings" in dataset["train"].column_names:
    dataset = dataset.rename_column("embeddings", "l14_embeddings")
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'caption', 'l14_embeddings'],
        num_rows: 62149
    })
})

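MoCo model: a ResNet-18 backbone with a MoCo projection head, trained contrastively against a momentum (EMA) copy of itself using lightly's NT-Xent loss with a memory bank.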

In [8]:
class MocoModel(pl.LightningModule):
    """MoCo-style contrastive model assembled from lightly building blocks.

    The query encoder (``backbone`` + ``projection_head``) is trained by SGD.
    The key encoder (``backbone_momentum`` + ``projection_head_momentum``) is a
    frozen deep copy of it that is moved towards the query encoder with
    ``update_momentum`` at every training step. Queries and keys are compared
    with lightly's ``NTXentLoss`` configured with a memory bank.

    Args:
        memory_bank_size: Memory-bank size passed to ``NTXentLoss``. Defaults to
            the notebook-level ``memory_bank_size`` value at class-definition time.
        max_epochs: ``T_max`` of the cosine learning-rate schedule. Defaults to
            the notebook-level ``max_epochs`` value at class-definition time.
        momentum: Coefficient passed to ``update_momentum`` for the key encoder.
        temperature: Temperature of the NT-Xent loss.
        lr: Initial SGD learning rate.
    """

    def __init__(
        self,
        memory_bank_size: int = memory_bank_size,
        max_epochs: int = max_epochs,
        momentum: float = 0.99,
        temperature: float = 0.1,
        lr: float = 6e-2,
    ):
        super().__init__()
        # Record the constructor arguments in self.hparams (and in checkpoints)
        # so training no longer depends on notebook globals at call time.
        self.save_hyperparameters()

        # Create a ResNet-18 backbone and drop its final (classification) child.
        # NOTE(review): lightly's ResNet appears to apply its stem ReLU
        # functionally in forward(), so rebuilding the network from children()
        # may omit that activation. Confirm against the installed lightly version.
        resnet = lightly.models.ResNetGenerator('resnet-18', 1, num_splits=8)
        self.backbone = nn.Sequential(
            *list(resnet.children())[:-1],
            nn.AdaptiveAvgPool2d(1),
        )

        # Query-side projection head (512 = backbone feature width).
        self.projection_head = MoCoProjectionHead(512, 512, 128)

        # Key-side encoder: a frozen copy that is only updated via momentum.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)
        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

        # Contrastive loss with the optional memory bank.
        self.criterion = lightly.loss.NTXentLoss(
            temperature=temperature,
            memory_bank_size=memory_bank_size,
        )

    def training_step(self, batch, batch_idx):
        """Compute the contrastive loss for one batch of two augmented views.

        ``batch`` is unpacked as ``((x_q, x_k), _, _)``. Presumably this is the
        (views, labels, filenames) layout produced by lightly's collate
        functions; TODO confirm against the DataLoader used for training.
        """
        (x_q, x_k), _, _ = batch
        m = self.hparams.momentum

        # Move the key encoder towards the current query encoder.
        update_momentum(self.backbone, self.backbone_momentum, m)
        update_momentum(self.projection_head, self.projection_head_momentum, m)

        # Queries: computed by the trainable encoder.
        q = self.backbone(x_q).flatten(start_dim=1)
        q = self.projection_head(q)

        # Keys: shuffle the batch before the momentum encoder, then restore the
        # original order so every key lines up with its query again.
        k, shuffle = batch_shuffle(x_k)
        k = self.backbone_momentum(k).flatten(start_dim=1)
        k = self.projection_head_momentum(k)
        k = batch_unshuffle(k, shuffle)

        loss = self.criterion(q, k)
        self.log("train_loss_ssl", loss)
        return loss

    def on_train_epoch_end(self):
        """Log weight histograms once per training epoch.

        Replaces the ``training_epoch_end`` hook, which was removed in
        PyTorch Lightning 2.0 and which in 1.x forces all step outputs to be
        kept in memory just to be passed to it (they were unused here).
        """
        self.custom_histogram_weights()

    def custom_histogram_weights(self):
        """Log a histogram of every named parameter to the experiment logger.

        Useful for debugging. Histograms are skipped when no logger is attached
        or the logger's experiment has no ``add_histogram`` (i.e. it is not
        TensorBoard-like), instead of crashing at the end of the epoch.
        """
        experiment = getattr(self.logger, "experiment", None)
        if not hasattr(experiment, "add_histogram"):
            return
        for name, params in self.named_parameters():
            experiment.add_histogram(name, params, self.current_epoch)

    def configure_optimizers(self):
        """SGD with momentum and weight decay, on a cosine LR schedule.

        The scheduler's ``T_max`` is ``hparams.max_epochs`` (stepped per epoch
        by Lightning's default scheduler interval).
        """
        optim = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim, self.hparams.max_epochs
        )
        return [optim], [scheduler]

In [9]:
# Instantiate the MoCo model with fresh (untrained) weights; nothing is loaded
# from a checkpoint here.
model = MocoModel()

In [11]:
# List parameter and buffer names (including BatchNorm running statistics) to
# inspect how the backbone's Sequential children are numbered.
model.state_dict().keys()

odict_keys(['backbone.0.weight', 'backbone.1.weight', 'backbone.1.bias', 'backbone.1.running_mean', 'backbone.1.running_var', 'backbone.1.num_batches_tracked', 'backbone.2.0.conv1.weight', 'backbone.2.0.bn1.weight', 'backbone.2.0.bn1.bias', 'backbone.2.0.bn1.running_mean', 'backbone.2.0.bn1.running_var', 'backbone.2.0.bn1.num_batches_tracked', 'backbone.2.0.conv2.weight', 'backbone.2.0.bn2.weight', 'backbone.2.0.bn2.bias', 'backbone.2.0.bn2.running_mean', 'backbone.2.0.bn2.running_var', 'backbone.2.0.bn2.num_batches_tracked', 'backbone.2.1.conv1.weight', 'backbone.2.1.bn1.weight', 'backbone.2.1.bn1.bias', 'backbone.2.1.bn1.running_mean', 'backbone.2.1.bn1.running_var', 'backbone.2.1.bn1.num_batches_tracked', 'backbone.2.1.conv2.weight', 'backbone.2.1.bn2.weight', 'backbone.2.1.bn2.bias', 'backbone.2.1.bn2.running_mean', 'backbone.2.1.bn2.running_var', 'backbone.2.1.bn2.num_batches_tracked', 'backbone.3.0.conv1.weight', 'backbone.3.0.bn1.weight', 'backbone.3.0.bn1.bias', 'backbone.3.0.b