This example requires the following dependencies to be installed:
pip install "lightly[timm]"

In [1]:
!pip install "lightly[timm]"

Collecting lightly[timm]
  Downloading lightly-1.5.15-py3-none-any.whl.metadata (36 kB)
Collecting hydra-core>=1.0.0 (from lightly[timm])
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting lightly_utils~=0.0.0 (from lightly[timm])
  Downloading lightly_utils-0.0.2-py3-none-any.whl.metadata (1.4 kB)
Collecting pytorch_lightning>=1.0.4 (from lightly[timm])
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting aenum>=3.1.11 (from lightly[timm])
  Downloading aenum-3.1.15-py3-none-any.whl.metadata (3.7 kB)
Collecting omegaconf<2.4,>=2.2 (from hydra-core>=1.0.0->lightly[timm])
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting antlr4-python3-runtime==4.9.* (from hydra-core>=1.0.0->lightly[timm])
  Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.0/117.0 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) 

Note: The model and training settings do not follow the reference settings
from the paper. The settings are chosen such that the example can easily be
run on a small dataset with a single GPU.

In [1]:
import torch
import torchvision
from torch import nn

In [2]:
from lightly.models import utils
from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer
from lightly.transforms import AIMTransform



In [3]:
class AIM(nn.Module):
    """AIM model: a vision transformer backbone followed by an AIM prediction head.

    The forward pass returns the head's predictions together with the
    normalized image patches they are compared against.

    Args:
        vit: Vision transformer used as backbone. It must expose ``pos_embed``,
            ``has_class_token``, ``patch_embed``, ``embed_dim``,
            ``forward_features(images, mask=...)`` and ``_pos_embed``. In this
            notebook a ``MaskedCausalVisionTransformer`` is passed in.
    """

    def __init__(self, vit):
        super().__init__()
        # Initialize the backbone's positional embedding with lightly's
        # 2D sine-cosine helper.
        utils.initialize_2d_sine_cosine_positional_embedding(
            pos_embedding=vit.pos_embed,
            has_class_token=vit.has_class_token,
        )

        # Only the first entry of the patch size is used.
        # NOTE(review): presumably patches are square -- confirm.
        patch_embed = vit.patch_embed
        self.patch_size = patch_embed.patch_size[0]
        self.num_patches = patch_embed.num_patches

        self.backbone = vit
        # Head output size: 3 * patch_size**2 values per patch
        # (presumably one value per colour channel and pixel of a patch).
        values_per_patch = 3 * self.patch_size**2
        self.projection_head = AIMPredictionHead(
            input_dim=vit.embed_dim,
            output_dim=values_per_patch,
            num_blocks=1,
        )

    def forward(self, images):
        """Run the backbone with a random prefix mask and build the patch targets.

        Args:
            images: Batch of images; the first dimension is the batch size.

        Returns:
            A ``(predictions, patches)`` tuple: the output of the prediction
            head and the mean/variance-normalized patches of ``images``.
        """
        # Draw a random prefix mask with one row per image in the batch.
        prefix_mask = utils.random_prefix_mask(
            size=(images.shape[0], self.num_patches),
            max_prefix_length=self.num_patches - 1,
            device=images.device,
        )
        tokens = self.backbone.forward_features(images, mask=prefix_mask)
        # Add positional embedding before head.
        tokens = self.backbone._pos_embed(tokens)
        predictions = self.projection_head(tokens)

        # Convert images to patches and normalize them.
        targets = utils.normalize_mean_var(
            utils.patchify(images, self.patch_size),
            dim=-1,
        )

        return predictions, targets

In [4]:
# Backbone: masked causal vision transformer for 224x224 inputs split into
# 32x32 patches (7 * 7 = 49 patches per image), built without a class token.
vit = MaskedCausalVisionTransformer(
    img_size=224,
    patch_size=32,
    embed_dim=768,
    depth=12,
    num_heads=12,
    qk_norm=False,
    class_token=False,
    no_embed_class=True,
)
# Wrap the backbone in the AIM module, which adds the prediction head.
model = AIM(vit)

In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Move the model to the selected device. Because this is the last expression
# of the cell, the returned module is also echoed as the cell output.
model.to(device)

AIM(
  (backbone): MaskedCausalVisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): MaskedCausalBlock(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): MaskedCausalAttention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (

In [6]:
# Image transform handed to the dataset. The training loop further down reads
# the result as a list of views and uses only the first one.
transform = AIMTransform()
# We ignore object detection annotations: the `target_transform` function
# defined in the next cell replaces every annotation with 0.

In [7]:
def target_transform(t):
    """Replace the annotation of a sample with a constant dummy target.

    Args:
        t: Annotation passed in by the dataset; it is ignored.

    Returns:
        Always ``0``.
    """
    dummy_target = 0
    return dummy_target

In [8]:
# Pascal VOC detection dataset, downloaded into "datasets/pascal_voc".
# Images are passed through `transform`, annotations through `target_transform`.
dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
# NOTE(review): LightlyDataset is not imported in the cells shown here; it is
# expected to come from `lightly.data` -- confirm before uncommenting.

Downloading http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar to datasets/pascal_voc/VOCtrainval_11-May-2012.tar


100%|██████████| 2.00G/2.00G [00:07<00:00, 251MB/s]


Extracting datasets/pascal_voc/VOCtrainval_11-May-2012.tar to datasets/pascal_voc


In [9]:
# Shuffled batches of 256 samples loaded by 8 worker processes; the last
# incomplete batch of every epoch is dropped.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)



In [10]:
# Mean squared error; the training loop applies it to the predictions and
# targets returned by the model.
criterion = nn.MSELoss()
# AdamW over all parameters of the model (backbone and prediction head).
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)

In [11]:
print("Starting Training")
# Train for 10 epochs and report the mean batch loss after each one.
for epoch in range(10):
    # Running sum of the detached batch losses of this epoch.
    total_loss = 0
    for batch in dataloader:
        # The first element of a batch holds the views produced by the transform.
        views = batch[0]
        images = views[0].to(device)  # views contains only a single view
        # The model returns its predictions together with the matching targets.
        predictions, targets = model(images)
        loss = criterion(predictions, targets)
        # Detach so the running sum does not keep the autograd graph alive.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        # Clear the gradients so they do not accumulate into the next step.
        optimizer.zero_grad()
    # len(dataloader) is the number of batches per epoch.
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

Starting Training
epoch: 00, loss: 0.78826
epoch: 01, loss: 0.57274
epoch: 02, loss: 0.48636
epoch: 03, loss: 0.44968
epoch: 04, loss: 0.42642
epoch: 05, loss: 0.41471
epoch: 06, loss: 0.39830
epoch: 07, loss: 0.37695
epoch: 08, loss: 0.35189
epoch: 09, loss: 0.32972
