diff --git a/benchmarks/imagenet/vitb16/finetune_eval.py b/benchmarks/imagenet/vitb16/finetune_eval.py index d89e1bfc7..869b01bfe 100644 --- a/benchmarks/imagenet/vitb16/finetune_eval.py +++ b/benchmarks/imagenet/vitb16/finetune_eval.py @@ -36,6 +36,8 @@ def __init__( super().__init__( model, batch_size_per_device, feature_dim, num_classes, topk, freeze_model ) + # TODO(Ersi, 2/24): Add path dropout for TIMM. + # Add path dropout. add_stochastic_depth_to_blocks(self.model, prob=0.1) # Add mixup and cutmix. @@ -140,7 +142,6 @@ def finetune_eval( Parameters follow MAE settings. """ print("Running fine-tune evaluation...") - # Setup training data. # NOTE: We use transforms from the timm library here as they are the default in MAE # and torchvision does not provide all required parameters. diff --git a/benchmarks/imagenet/vitb16/mae.py b/benchmarks/imagenet/vitb16/mae.py new file mode 100644 index 000000000..29c59585c --- /dev/null +++ b/benchmarks/imagenet/vitb16/mae.py @@ -0,0 +1,155 @@ +import sys +from typing import List, Tuple + +import torch +from pytorch_lightning import LightningModule +from timm.models.vision_transformer import vit_base_patch16_224 +from torch import Tensor +from torch.nn import MSELoss, Parameter +from torch.optim import AdamW + +from lightly.models import utils +from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM +from lightly.transforms import MAETransform +from lightly.utils.benchmarking import OnlineLinearClassifier +from lightly.utils.scheduler import CosineWarmupScheduler + + +class MAE(LightningModule): + def __init__(self, batch_size_per_device: int, num_classes: int) -> None: + super().__init__() + self.save_hyperparameters() + self.batch_size_per_device = batch_size_per_device + + decoder_dim = 512 + vit = vit_base_patch16_224() + + self.mask_ratio = 0.75 + self.patch_size = vit.patch_embed.patch_size[0] + self.sequence_length = vit.patch_embed.num_patches + vit.num_prefix_tokens + mask_token = Parameter(torch.zeros(1, 1, decoder_dim)) + torch.nn.init.normal_(mask_token, std=0.02) + self.backbone = MaskedVisionTransformerTIMM(vit=vit) + self.decoder = MAEDecoderTIMM( + num_patches=vit.patch_embed.num_patches, + patch_size=self.patch_size, + embed_dim=vit.embed_dim, + decoder_embed_dim=decoder_dim, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4.0, + proj_drop_rate=0.0, + attn_drop_rate=0.0, + mask_token=mask_token, + ) + self.criterion = MSELoss() + + self.online_classifier = OnlineLinearClassifier( + feature_dim=768, num_classes=num_classes + ) + + def forward(self, x: Tensor) -> Tensor: + return self.backbone(images=x) + + def forward_encoder(self, images, idx_keep=None): + return self.backbone.encode(images=images, idx_keep=idx_keep) + + def forward_decoder(self, x_encoded, idx_keep, idx_mask): + # build decoder input + batch_size = x_encoded.shape[0] + x_decode = self.decoder.embed(x_encoded) + x_masked = utils.repeat_token( + self.decoder.mask_token, (batch_size, self.sequence_length) + ) + x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked)) + + # decoder forward pass + x_decoded = self.decoder.decode(x_masked) + + # predict pixel values for masked tokens + x_pred = utils.get_at_index(x_decoded, idx_mask) + x_pred = self.decoder.predict(x_pred) + return x_pred + + def training_step( + self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int + ) -> Tensor: + images, targets = batch[0], batch[1] + images = images[0] # images is a list containing only one view + batch_size = 
images.shape[0] + idx_keep, idx_mask = utils.random_token_mask( + size=(batch_size, self.sequence_length), + mask_ratio=self.mask_ratio, + device=images.device, + ) + features = self.forward_encoder(images, idx_keep) + predictions = self.forward_decoder(features, idx_keep, idx_mask) + + # get image patches for masked tokens + patches = utils.patchify(images, self.patch_size) + # must adjust idx_mask for missing class token + target = utils.get_at_index(patches, idx_mask - 1) + + loss = self.criterion(predictions, target) + self.log( + "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets) + ) + + cls_features = features[:, 0] + cls_loss, cls_log = self.online_classifier.training_step( + (cls_features.detach(), targets), batch_idx + ) + self.log_dict(cls_log, sync_dist=True, batch_size=len(targets)) + return loss + cls_loss + + def validation_step( + self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int + ) -> Tensor: + images, targets = batch[0], batch[1] + cls_features = self.forward(images).flatten(start_dim=1) + cls_loss, cls_log = self.online_classifier.validation_step( + (cls_features.detach(), targets), batch_idx + ) + self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets)) + return cls_loss + + def configure_optimizers(self): + # Don't use weight decay for batch norm, bias parameters, and classification + # head to improve performance. + params, params_no_weight_decay = utils.get_weight_decay_parameters( + [self.backbone, self.decoder] + ) + optimizer = AdamW( + [ + {"name": "mae", "params": params}, + { + "name": "mae_no_weight_decay", + "params": params_no_weight_decay, + "weight_decay": 0.0, + }, + { + "name": "online_classifier", + "params": self.online_classifier.parameters(), + "weight_decay": 0.0, + }, + ], + lr=1.5e-4 * self.batch_size_per_device * self.trainer.world_size / 256, + weight_decay=0.05, + betas=(0.9, 0.95), + ) + scheduler = { + "scheduler": CosineWarmupScheduler( + optimizer=optimizer, + warmup_epochs=( + self.trainer.estimated_stepping_batches + / self.trainer.max_epochs + * 40 + ), + max_epochs=self.trainer.estimated_stepping_batches, + ), + "interval": "step", + } + return [optimizer], [scheduler] + + +transform = MAETransform() diff --git a/benchmarks/imagenet/vitb16/main.py b/benchmarks/imagenet/vitb16/main.py index ca7967746..8efc8c17f 100644 --- a/benchmarks/imagenet/vitb16/main.py +++ b/benchmarks/imagenet/vitb16/main.py @@ -7,6 +7,7 @@ import finetune_eval import knn_eval import linear_eval +import mae import torch from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import DeviceStatsMonitor, LearningRateMonitor @@ -38,7 +39,9 @@ parser.add_argument("--float32-matmul-precision", type=str, default="high") parser.add_argument("--strategy", default="ddp_find_unused_parameters_true") + METHODS = { + "mae": {"model": mae.MAE, "transform": mae.transform}, "aim": {"model": aim.AIM, "transform": aim.transform}, } diff --git a/docs/source/getting_started/benchmarks/imagenette_benchmark.py b/docs/source/getting_started/benchmarks/imagenette_benchmark.py index ee9cc0171..e893cfeae 100644 --- a/docs/source/getting_started/benchmarks/imagenette_benchmark.py +++ b/docs/source/getting_started/benchmarks/imagenette_benchmark.py @@ -61,6 +61,7 @@ """ import copy import os +import sys import time import numpy as np @@ -70,6 +71,8 @@ import torchvision from pl_bolts.optimizers.lars import LARS from pytorch_lightning.loggers import TensorBoardLogger +from timm.models.vision_transformer 
import vit_base_patch32_224 +from torchvision.models.vision_transformer import VisionTransformer from lightly.data import LightlyDataset from lightly.loss import ( @@ -87,7 +90,13 @@ VICRegLoss, ) from lightly.models import modules, utils -from lightly.models.modules import heads, masked_autoencoder, memory_bank +from lightly.models.modules import ( + MAEDecoderTIMM, + MaskedVisionTransformerTIMM, + MaskedVisionTransformerTorchvision, + heads, + memory_bank, +) from lightly.transforms import ( BYOLTransform, BYOLView1Transform, @@ -135,7 +144,7 @@ # benchmark n_runs = 1 # optional, increase to create multiple runs and report mean + std -batch_size = 256 +batch_size = 128 lr_factor = batch_size / 256 # scales the learning rate linearly with batch size # Number of devices and hardware to use for training. @@ -156,6 +165,7 @@ path_to_train = "/datasets/imagenette2-160/train/" path_to_test = "/datasets/imagenette2-160/val/" + # Use BYOL augmentations byol_transform = BYOLTransform( view_1_transform=BYOLView1Transform(input_size=input_size), @@ -768,37 +778,36 @@ class MAEModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) + vit = vit_base_patch32_224(dynamic_img_size=True, dynamic_img_pad=True) decoder_dim = 512 - vit = torchvision.models.vit_b_32(pretrained=False) - self.warmup_epochs = 40 if max_epochs >= 800 else 20 self.mask_ratio = 0.75 - self.patch_size = vit.patch_size - self.sequence_length = vit.seq_length - self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) - self.decoder = masked_autoencoder.MAEDecoder( - seq_length=vit.seq_length, - num_layers=1, - num_heads=16, - embed_input_dim=vit.hidden_dim, - hidden_dim=decoder_dim, - mlp_dim=decoder_dim * 4, - out_dim=vit.patch_size**2 * 3, - dropout=0, - attention_dropout=0, + self.patch_size = vit.patch_embed.patch_size[0] + self.sequence_length = vit.patch_embed.num_patches + 1 + self.backbone = MaskedVisionTransformerTIMM(vit=vit) + self.decoder = MAEDecoderTIMM( + num_patches=vit.patch_embed.num_patches, + patch_size=self.patch_size, + in_chans=3, + embed_dim=vit.embed_dim, + decoder_embed_dim=decoder_dim, + decoder_depth=1, + decoder_num_heads=16, + mlp_ratio=4.0, + proj_drop_rate=0.0, + attn_drop_rate=0.0, ) self.criterion = nn.MSELoss() def forward_encoder(self, images, idx_keep=None): - return self.backbone.encode(images, idx_keep) + return self.backbone.encode(images, idx_keep=idx_keep) def forward_decoder(self, x_encoded, idx_keep, idx_mask): # build decoder input batch_size = x_encoded.shape[0] x_decode = self.decoder.embed(x_encoded) x_masked = utils.repeat_token( - self.mask_token, (batch_size, self.sequence_length) + self.decoder.mask_token, (batch_size, self.sequence_length) ) x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked)) @@ -819,7 +828,7 @@ def training_step(self, batch, batch_idx): mask_ratio=self.mask_ratio, device=images.device, ) - x_encoded = self.forward_encoder(images, idx_keep) + x_encoded = self.forward_encoder(images, idx_keep=idx_keep) x_pred = self.forward_decoder(x_encoded, idx_keep, idx_mask) # get image patches for masked tokens @@ -851,7 +860,7 @@ def __init__(self, dataloader_kNN, num_classes): self.warmup_epochs = 15 # ViT small configuration (ViT-S/16) self.mask_ratio = 0.15 - self.backbone = masked_autoencoder.MAEBackbone( + vit = VisionTransformer( image_size=224, patch_size=16, num_layers=12, @@ -859,6 +868,8 @@ def __init__(self, 
dataloader_kNN, num_classes): hidden_dim=384, mlp_dim=384 * 4, ) + self.backbone = MaskedVisionTransformerTorchvision(vit=vit) + self.projection_head = heads.MSNProjectionHead(384) self.anchor_backbone = copy.deepcopy(self.backbone) @@ -892,13 +903,13 @@ def training_step(self, batch, batch_idx): def encode_masked(self, anchors): batch_size, _, _, width = anchors.shape - seq_length = (width // self.anchor_backbone.patch_size) ** 2 + seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2 idx_keep, _ = utils.random_token_mask( size=(batch_size, seq_length), mask_ratio=self.mask_ratio, device=self.device, ) - out = self.anchor_backbone(anchors, idx_keep) + out = self.anchor_backbone(images=anchors, idx_keep=idx_keep) return self.anchor_projection_head(out) def configure_optimizers(self): @@ -926,7 +937,7 @@ def __init__(self, dataloader_kNN, num_classes): self.warmup_epochs = 15 # ViT small configuration (ViT-S/16) self.mask_ratio = 0.15 - self.backbone = masked_autoencoder.MAEBackbone( + vit = VisionTransformer( image_size=224, patch_size=16, num_layers=12, @@ -934,6 +945,7 @@ def __init__(self, dataloader_kNN, num_classes): hidden_dim=384, mlp_dim=384 * 4, ) + self.backbone = MaskedVisionTransformerTorchvision(vit=vit) self.projection_head = heads.MSNProjectionHead(384) self.anchor_backbone = copy.deepcopy(self.backbone) @@ -967,13 +979,13 @@ def training_step(self, batch, batch_idx): def encode_masked(self, anchors): batch_size, _, _, width = anchors.shape - seq_length = (width // self.anchor_backbone.patch_size) ** 2 + seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2 idx_keep, _ = utils.random_token_mask( size=(batch_size, seq_length), mask_ratio=self.mask_ratio, device=self.device, ) - out = self.anchor_backbone(anchors, idx_keep) + out = self.anchor_backbone(images=anchors, idx_keep=idx_keep) return self.anchor_projection_head(out) def configure_optimizers(self): @@ -1107,10 +1119,12 @@ def __init__(self, dataloader_kNN, num_classes): self.mask_ratio = 0.75 self.patch_size = vit.patch_size self.sequence_length = vit.seq_length - self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) + mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) - # same backbone as MAE - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) + # Masked vision transformer as backbone + self.backbone = MaskedVisionTransformerTorchvision( + vit=vit, mask_token=mask_token + ) # the decoder is a simple linear layer self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size**2 * 3) @@ -1119,10 +1133,7 @@ def __init__(self, dataloader_kNN, num_classes): self.criterion = nn.L1Loss() def forward_encoder(self, images, batch_size, idx_mask): - # pass all the tokens to the encoder, both masked and non masked ones - tokens = self.backbone.images_to_tokens(images, prepend_class_token=True) - tokens_masked = utils.mask_at_index(tokens, idx_mask, self.mask_token) - return self.backbone.encoder(tokens_masked) + return self.backbone.encode(images=images, idx_mask=idx_mask, idx_keep=None) def forward_decoder(self, x_encoded): return self.decoder(x_encoded) @@ -1403,13 +1414,13 @@ def configure_optimizers(self): DCLW, DINOModel, FastSiamModel, - # MAEModel, # disabled by default because MAE uses larger images with size 224 + # MAEModel, # disabled by default because MAE uses larger images with size 224 MSNModel, MocoModel, NNCLRModel, PMSNModel, SimCLRModel, - # SimMIMModel, # disabled by default because SimMIM uses larger images with size 224 + # SimMIMModel, # disabled by default because 
SimMIM uses larger images with size 224 SimSiamModel, SwaVModel, SwaVQueueModel, @@ -1418,6 +1429,7 @@ def configure_optimizers(self): VICRegModel, VICRegLModel, ] + bench_results = dict() experiment_version = None diff --git a/docs/source/getting_started/install.rst b/docs/source/getting_started/install.rst index 9ea074263..c7cc71e71 100644 --- a/docs/source/getting_started/install.rst +++ b/docs/source/getting_started/install.rst @@ -32,6 +32,13 @@ If you want to work with video files you need to additionally install pip install av +If you want to work use the Masked Autoencoder you need to additionally install +`TIMM `_. + +.. code-block:: bash + + pip install timm + Next Steps ------------ diff --git a/examples/pytorch/mae.py b/examples/pytorch/mae.py index 999404ebb..1f22ae8ec 100644 --- a/examples/pytorch/mae.py +++ b/examples/pytorch/mae.py @@ -1,13 +1,13 @@ # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. - import torch import torchvision +from timm.models.vision_transformer import vit_base_patch32_224 from torch import nn from lightly.models import utils -from lightly.models.modules import masked_autoencoder +from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM from lightly.transforms.mae_transform import MAETransform @@ -17,31 +17,31 @@ def __init__(self, vit): decoder_dim = 512 self.mask_ratio = 0.75 - self.patch_size = vit.patch_size - self.sequence_length = vit.seq_length - self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) - self.decoder = masked_autoencoder.MAEDecoder( - seq_length=vit.seq_length, - num_layers=1, - num_heads=16, - embed_input_dim=vit.hidden_dim, - hidden_dim=decoder_dim, - mlp_dim=decoder_dim * 4, - out_dim=vit.patch_size**2 * 3, - dropout=0, - attention_dropout=0, + self.patch_size = vit.patch_embed.patch_size[0] + + self.backbone = MaskedVisionTransformerTIMM(vit=vit) + self.sequence_length = self.backbone.sequence_length + self.decoder = MAEDecoderTIMM( + num_patches=vit.patch_embed.num_patches, + patch_size=self.patch_size, + embed_dim=vit.embed_dim, + decoder_embed_dim=decoder_dim, + decoder_depth=1, + decoder_num_heads=16, + mlp_ratio=4.0, + proj_drop_rate=0.0, + attn_drop_rate=0.0, ) def forward_encoder(self, images, idx_keep=None): - return self.backbone.encode(images, idx_keep) + return self.backbone.encode(images=images, idx_keep=idx_keep) def forward_decoder(self, x_encoded, idx_keep, idx_mask): # build decoder input batch_size = x_encoded.shape[0] x_decode = self.decoder.embed(x_encoded) x_masked = utils.repeat_token( - self.mask_token, (batch_size, self.sequence_length) + self.decoder.mask_token, (batch_size, self.sequence_length) ) x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked)) @@ -60,8 +60,10 @@ def forward(self, images): mask_ratio=self.mask_ratio, device=images.device, ) - x_encoded = self.forward_encoder(images, idx_keep) - x_pred = self.forward_decoder(x_encoded, idx_keep, idx_mask) + x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep) + x_pred = self.forward_decoder( + x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask + ) # get image patches for masked tokens patches = utils.patchify(images, self.patch_size) @@ -70,7 +72,7 @@ def forward(self, images): return x_pred, target -vit = torchvision.models.vit_b_32(pretrained=False) +vit = 
vit_base_patch32_224() model = MAE(vit) device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/examples/pytorch/msn.py b/examples/pytorch/msn.py index 4e7efbd88..f69eb7dae 100644 --- a/examples/pytorch/msn.py +++ b/examples/pytorch/msn.py @@ -9,8 +9,8 @@ from lightly.loss import MSNLoss from lightly.models import utils +from lightly.models.modules import MaskedVisionTransformerTorchvision from lightly.models.modules.heads import MSNProjectionHead -from lightly.models.modules.masked_autoencoder import MAEBackbone from lightly.transforms.msn_transform import MSNTransform @@ -19,7 +19,7 @@ def __init__(self, vit): super().__init__() self.mask_ratio = 0.15 - self.backbone = MAEBackbone.from_vit(vit) + self.backbone = MaskedVisionTransformerTorchvision(vit=vit) self.projection_head = MSNProjectionHead(input_dim=384) self.anchor_backbone = copy.deepcopy(self.backbone) @@ -31,18 +31,18 @@ def __init__(self, vit): self.prototypes = nn.Linear(256, 1024, bias=False).weight def forward(self, images): - out = self.backbone(images) + out = self.backbone(images=images) return self.projection_head(out) def forward_masked(self, images): batch_size, _, _, width = images.shape - seq_length = (width // self.anchor_backbone.patch_size) ** 2 + seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2 idx_keep, _ = utils.random_token_mask( size=(batch_size, seq_length), mask_ratio=self.mask_ratio, device=images.device, ) - out = self.anchor_backbone(images, idx_keep) + out = self.anchor_backbone(images=images, idx_keep=idx_keep) return self.anchor_projection_head(out) diff --git a/examples/pytorch/pmsn.py b/examples/pytorch/pmsn.py index d8d03bd17..545e38fc4 100644 --- a/examples/pytorch/pmsn.py +++ b/examples/pytorch/pmsn.py @@ -9,8 +9,8 @@ from lightly.loss import PMSNLoss from lightly.models import utils +from lightly.models.modules import MaskedVisionTransformerTorchvision from lightly.models.modules.heads import MSNProjectionHead -from lightly.models.modules.masked_autoencoder import MAEBackbone from lightly.transforms import MSNTransform @@ -19,7 +19,7 @@ def __init__(self, vit): super().__init__() self.mask_ratio = 0.15 - self.backbone = MAEBackbone.from_vit(vit) + self.backbone = MaskedVisionTransformerTorchvision(vit=vit) self.projection_head = MSNProjectionHead(384) self.anchor_backbone = copy.deepcopy(self.backbone) @@ -31,18 +31,18 @@ def __init__(self, vit): self.prototypes = nn.Linear(256, 1024, bias=False).weight def forward(self, images): - out = self.backbone(images) + out = self.backbone(images=images) return self.projection_head(out) def forward_masked(self, images): batch_size, _, _, width = images.shape - seq_length = (width // self.anchor_backbone.patch_size) ** 2 + seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2 idx_keep, _ = utils.random_token_mask( size=(batch_size, seq_length), mask_ratio=self.mask_ratio, device=images.device, ) - out = self.anchor_backbone(images, idx_keep) + out = self.anchor_backbone(images=images, idx_keep=idx_keep) return self.anchor_projection_head(out) @@ -106,7 +106,7 @@ def forward_masked(self, images): anchors = views[1] anchors_focal = torch.concat(views[2:], dim=0) - targets_out = model.backbone(targets) + targets_out = model.backbone(images=targets) targets_out = model.projection_head(targets_out) anchors_out = model.forward_masked(anchors) anchors_focal_out = model.forward_masked(anchors_focal) diff --git a/examples/pytorch/simmim.py b/examples/pytorch/simmim.py index c3d660a9c..74d8b7f97 100644 --- 
a/examples/pytorch/simmim.py +++ b/examples/pytorch/simmim.py @@ -3,7 +3,9 @@ from torch import nn from lightly.models import utils -from lightly.models.modules import masked_autoencoder +from lightly.models.modules.masked_vision_transformer_torchvision import ( + MaskedVisionTransformerTorchvision, +) from lightly.transforms.mae_transform import MAETransform # Same transform as MAE @@ -15,19 +17,15 @@ def __init__(self, vit): self.mask_ratio = 0.75 self.patch_size = vit.patch_size self.sequence_length = vit.seq_length - self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) - # same backbone as MAE - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) + self.backbone = MaskedVisionTransformerTorchvision(vit=vit) # the decoder is a simple linear layer - self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size**2 * 3) + self.decoder = nn.Linear(decoder_dim, vit.patch_size**2 * 3) def forward_encoder(self, images, batch_size, idx_mask): # pass all the tokens to the encoder, both masked and non masked ones - tokens = self.backbone.images_to_tokens(images, prepend_class_token=True) - tokens_masked = utils.mask_at_index(tokens, idx_mask, self.mask_token) - return self.backbone.encoder(tokens_masked) + return self.backbone.encode(images=images, idx_mask=idx_mask) def forward_decoder(self, x_encoded): return self.decoder(x_encoded) diff --git a/examples/pytorch_lightning/mae.py b/examples/pytorch_lightning/mae.py index ca3044eb5..832ebab17 100644 --- a/examples/pytorch_lightning/mae.py +++ b/examples/pytorch_lightning/mae.py @@ -1,14 +1,14 @@ # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. - import pytorch_lightning as pl import torch import torchvision +from timm.models.vision_transformer import vit_base_patch32_224 from torch import nn from lightly.models import utils -from lightly.models.modules import masked_autoencoder +from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM from lightly.transforms.mae_transform import MAETransform @@ -17,34 +17,33 @@ def __init__(self): super().__init__() decoder_dim = 512 - vit = torchvision.models.vit_b_32(pretrained=False) + vit = vit_base_patch32_224() self.mask_ratio = 0.75 - self.patch_size = vit.patch_size - self.sequence_length = vit.seq_length - self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) - self.decoder = masked_autoencoder.MAEDecoder( - seq_length=vit.seq_length, - num_layers=1, - num_heads=16, - embed_input_dim=vit.hidden_dim, - hidden_dim=decoder_dim, - mlp_dim=decoder_dim * 4, - out_dim=vit.patch_size**2 * 3, - dropout=0, - attention_dropout=0, + self.patch_size = vit.patch_embed.patch_size[0] + self.backbone = MaskedVisionTransformerTIMM(vit=vit) + self.sequence_length = self.backbone.sequence_length + self.decoder = MAEDecoderTIMM( + num_patches=vit.patch_embed.num_patches, + patch_size=self.patch_size, + embed_dim=vit.embed_dim, + decoder_embed_dim=decoder_dim, + decoder_depth=1, + decoder_num_heads=16, + mlp_ratio=4.0, + proj_drop_rate=0.0, + attn_drop_rate=0.0, ) self.criterion = nn.MSELoss() def forward_encoder(self, images, idx_keep=None): - return self.backbone.encode(images, idx_keep) + return self.backbone.encode(images=images, idx_keep=idx_keep) def forward_decoder(self, x_encoded, idx_keep, idx_mask): # build decoder input batch_size = x_encoded.shape[0] x_decode = 
self.decoder.embed(x_encoded) x_masked = utils.repeat_token( - self.mask_token, (batch_size, self.sequence_length) + self.decoder.mask_token, (batch_size, self.sequence_length) ) x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked)) @@ -65,8 +64,10 @@ def training_step(self, batch, batch_idx): mask_ratio=self.mask_ratio, device=images.device, ) - x_encoded = self.forward_encoder(images, idx_keep) - x_pred = self.forward_decoder(x_encoded, idx_keep, idx_mask) + x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep) + x_pred = self.forward_decoder( + x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask + ) # get image patches for masked tokens patches = utils.patchify(images, self.patch_size) diff --git a/examples/pytorch_lightning/msn.py b/examples/pytorch_lightning/msn.py index 61bc3c8c9..91569b215 100644 --- a/examples/pytorch_lightning/msn.py +++ b/examples/pytorch_lightning/msn.py @@ -10,8 +10,8 @@ from lightly.loss import MSNLoss from lightly.models import utils +from lightly.models.modules import MaskedVisionTransformerTorchvision from lightly.models.modules.heads import MSNProjectionHead -from lightly.models.modules.masked_autoencoder import MAEBackbone from lightly.transforms.msn_transform import MSNTransform @@ -21,7 +21,8 @@ def __init__(self): # ViT small configuration (ViT-S/16) self.mask_ratio = 0.15 - self.backbone = MAEBackbone( + # ViT small configuration (ViT-S/16) + vit = torchvision.models.VisionTransformer( image_size=224, patch_size=16, num_layers=12, @@ -29,6 +30,7 @@ def __init__(self): hidden_dim=384, mlp_dim=384 * 4, ) + self.backbone = MaskedVisionTransformerTorchvision(vit=vit) # or use a torchvision ViT backbone: # vit = torchvision.models.vit_b_32(pretrained=False) # self.backbone = MAEBackbone.from_vit(vit) @@ -53,7 +55,7 @@ def training_step(self, batch, batch_idx): anchors = views[1] anchors_focal = torch.concat(views[2:], dim=0) - targets_out = self.backbone(targets) + targets_out = self.backbone(images=targets) targets_out = self.projection_head(targets_out) anchors_out = self.encode_masked(anchors) anchors_focal_out = self.encode_masked(anchors_focal) @@ -64,13 +66,13 @@ def training_step(self, batch, batch_idx): def encode_masked(self, anchors): batch_size, _, _, width = anchors.shape - seq_length = (width // self.anchor_backbone.patch_size) ** 2 + seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2 idx_keep, _ = utils.random_token_mask( size=(batch_size, seq_length), mask_ratio=self.mask_ratio, device=self.device, ) - out = self.anchor_backbone(anchors, idx_keep) + out = self.anchor_backbone(images=anchors, idx_keep=idx_keep) return self.anchor_projection_head(out) def configure_optimizers(self): diff --git a/examples/pytorch_lightning/pmsn.py b/examples/pytorch_lightning/pmsn.py index e1e3ab95d..7506c3c1d 100644 --- a/examples/pytorch_lightning/pmsn.py +++ b/examples/pytorch_lightning/pmsn.py @@ -10,8 +10,8 @@ from lightly.loss import PMSNLoss from lightly.models import utils +from lightly.models.modules import MaskedVisionTransformerTorchvision from lightly.models.modules.heads import MSNProjectionHead -from lightly.models.modules.masked_autoencoder import MAEBackbone from lightly.transforms import MSNTransform @@ -21,7 +21,7 @@ def __init__(self): # ViT small configuration (ViT-S/16) self.mask_ratio = 0.15 - self.backbone = MAEBackbone( + vit = torchvision.models.VisionTransformer( image_size=224, patch_size=16, num_layers=12, @@ -29,6 +29,7 @@ def __init__(self): hidden_dim=384, mlp_dim=384 * 4, ) 
+ self.backbone = MaskedVisionTransformerTorchvision(vit=vit) # or use a torchvision ViT backbone: # vit = torchvision.models.vit_b_32(pretrained=False) # self.backbone = MAEBackbone.from_vit(vit) @@ -53,7 +54,7 @@ def training_step(self, batch, batch_idx): anchors = views[1] anchors_focal = torch.concat(views[2:], dim=0) - targets_out = self.backbone(targets) + targets_out = self.backbone(images=targets) targets_out = self.projection_head(targets_out) anchors_out = self.encode_masked(anchors) anchors_focal_out = self.encode_masked(anchors_focal) @@ -64,13 +65,13 @@ def training_step(self, batch, batch_idx): def encode_masked(self, anchors): batch_size, _, _, width = anchors.shape - seq_length = (width // self.anchor_backbone.patch_size) ** 2 + seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2 idx_keep, _ = utils.random_token_mask( size=(batch_size, seq_length), mask_ratio=self.mask_ratio, device=self.device, ) - out = self.anchor_backbone(anchors, idx_keep) + out = self.anchor_backbone(images=anchors, idx_keep=idx_keep) return self.anchor_projection_head(out) def configure_optimizers(self): diff --git a/examples/pytorch_lightning/simmim.py b/examples/pytorch_lightning/simmim.py index 95a14b21d..00b98b74f 100644 --- a/examples/pytorch_lightning/simmim.py +++ b/examples/pytorch_lightning/simmim.py @@ -4,7 +4,7 @@ from torch import nn from lightly.models import utils -from lightly.models.modules import masked_autoencoder +from lightly.models.modules import MaskedVisionTransformerTorchvision from lightly.transforms.mae_transform import MAETransform # Same transform as MAE @@ -19,8 +19,7 @@ def __init__(self): self.sequence_length = vit.seq_length self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) - # same backbone as MAE - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) + self.backbone = MaskedVisionTransformerTorchvision(vit=vit) # the decoder is a simple linear layer self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size**2 * 3) @@ -30,9 +29,7 @@ def __init__(self): def forward_encoder(self, images, batch_size, idx_mask): # pass all the tokens to the encoder, both masked and non masked ones - tokens = self.backbone.images_to_tokens(images, prepend_class_token=True) - tokens_masked = utils.mask_at_index(tokens, idx_mask, self.mask_token) - return self.backbone.encoder(tokens_masked) + return self.backbone.encode(images=images, idx_mask=idx_mask) def forward_decoder(self, x_encoded): return self.decoder(x_encoded) diff --git a/examples/pytorch_lightning_distributed/mae.py b/examples/pytorch_lightning_distributed/mae.py index c225b970c..e3b6ccfec 100644 --- a/examples/pytorch_lightning_distributed/mae.py +++ b/examples/pytorch_lightning_distributed/mae.py @@ -1,14 +1,14 @@ # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
- import pytorch_lightning as pl import torch import torchvision +from timm.models.vision_transformer import vit_base_patch32_224 from torch import nn from lightly.models import utils -from lightly.models.modules import masked_autoencoder +from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM from lightly.transforms.mae_transform import MAETransform @@ -17,34 +17,33 @@ def __init__(self): super().__init__() decoder_dim = 512 - vit = torchvision.models.vit_b_32(pretrained=False) + vit = vit_base_patch32_224() self.mask_ratio = 0.75 - self.patch_size = vit.patch_size - self.sequence_length = vit.seq_length - self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) - self.decoder = masked_autoencoder.MAEDecoder( - seq_length=vit.seq_length, - num_layers=1, - num_heads=16, - embed_input_dim=vit.hidden_dim, - hidden_dim=decoder_dim, - mlp_dim=decoder_dim * 4, - out_dim=vit.patch_size**2 * 3, - dropout=0, - attention_dropout=0, + self.patch_size = vit.patch_embed.patch_size[0] + self.backbone = MaskedVisionTransformerTIMM(vit=vit) + self.sequence_length = self.backbone.sequence_length + self.decoder = MAEDecoderTIMM( + num_patches=vit.patch_embed.num_patches, + patch_size=self.patch_size, + embed_dim=vit.embed_dim, + decoder_embed_dim=decoder_dim, + decoder_depth=1, + decoder_num_heads=16, + mlp_ratio=4.0, + proj_drop_rate=0.0, + attn_drop_rate=0.0, ) self.criterion = nn.MSELoss() def forward_encoder(self, images, idx_keep=None): - return self.backbone.encode(images, idx_keep) + return self.backbone.encode(images=images, idx_keep=idx_keep) def forward_decoder(self, x_encoded, idx_keep, idx_mask): # build decoder input batch_size = x_encoded.shape[0] x_decode = self.decoder.embed(x_encoded) x_masked = utils.repeat_token( - self.mask_token, (batch_size, self.sequence_length) + self.decoder.mask_token, (batch_size, self.sequence_length) ) x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked)) @@ -65,8 +64,10 @@ def training_step(self, batch, batch_idx): mask_ratio=self.mask_ratio, device=images.device, ) - x_encoded = self.forward_encoder(images, idx_keep) - x_pred = self.forward_decoder(x_encoded, idx_keep, idx_mask) + x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep) + x_pred = self.forward_decoder( + x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask + ) # get image patches for masked tokens patches = utils.patchify(images, self.patch_size) diff --git a/examples/pytorch_lightning_distributed/msn.py b/examples/pytorch_lightning_distributed/msn.py index 4d62c80c8..8b5c565ae 100644 --- a/examples/pytorch_lightning_distributed/msn.py +++ b/examples/pytorch_lightning_distributed/msn.py @@ -10,8 +10,8 @@ from lightly.loss import MSNLoss from lightly.models import utils +from lightly.models.modules import MaskedVisionTransformerTorchvision from lightly.models.modules.heads import MSNProjectionHead -from lightly.models.modules.masked_autoencoder import MAEBackbone from lightly.transforms.msn_transform import MSNTransform @@ -21,7 +21,8 @@ def __init__(self): # ViT small configuration (ViT-S/16) self.mask_ratio = 0.15 - self.backbone = MAEBackbone( + # ViT small configuration (ViT-S/16) + vit = torchvision.models.VisionTransformer( image_size=224, patch_size=16, num_layers=12, @@ -29,6 +30,7 @@ def __init__(self): hidden_dim=384, mlp_dim=384 * 4, ) + self.backbone = MaskedVisionTransformerTorchvision(vit=vit) # or use a torchvision ViT backbone: # vit = 
torchvision.models.vit_b_32(pretrained=False) # self.backbone = MAEBackbone.from_vit(vit) @@ -55,7 +57,7 @@ def training_step(self, batch, batch_idx): anchors = views[1] anchors_focal = torch.concat(views[2:], dim=0) - targets_out = self.backbone(targets) + targets_out = self.backbone(images=targets) targets_out = self.projection_head(targets_out) anchors_out = self.encode_masked(anchors) anchors_focal_out = self.encode_masked(anchors_focal) @@ -66,13 +68,13 @@ def training_step(self, batch, batch_idx): def encode_masked(self, anchors): batch_size, _, _, width = anchors.shape - seq_length = (width // self.anchor_backbone.patch_size) ** 2 + seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2 idx_keep, _ = utils.random_token_mask( size=(batch_size, seq_length), mask_ratio=self.mask_ratio, device=self.device, ) - out = self.anchor_backbone(anchors, idx_keep) + out = self.anchor_backbone(images=anchors, idx_keep=idx_keep) return self.anchor_projection_head(out) def configure_optimizers(self): diff --git a/examples/pytorch_lightning_distributed/pmsn.py b/examples/pytorch_lightning_distributed/pmsn.py index a3ae1a7c9..7af776096 100644 --- a/examples/pytorch_lightning_distributed/pmsn.py +++ b/examples/pytorch_lightning_distributed/pmsn.py @@ -10,8 +10,8 @@ from lightly.loss import PMSNLoss from lightly.models import utils +from lightly.models.modules import MaskedVisionTransformerTorchvision from lightly.models.modules.heads import MSNProjectionHead -from lightly.models.modules.masked_autoencoder import MAEBackbone from lightly.transforms import MSNTransform @@ -21,7 +21,7 @@ def __init__(self): # ViT small configuration (ViT-S/16) self.mask_ratio = 0.15 - self.backbone = MAEBackbone( + vit = torchvision.models.VisionTransformer( image_size=224, patch_size=16, num_layers=12, @@ -29,6 +29,7 @@ def __init__(self): hidden_dim=384, mlp_dim=384 * 4, ) + self.backbone = MaskedVisionTransformerTorchvision(vit=vit) # or use a torchvision ViT backbone: # vit = torchvision.models.vit_b_32(pretrained=False) # self.backbone = MAEBackbone.from_vit(vit) @@ -55,7 +56,7 @@ def training_step(self, batch, batch_idx): anchors = views[1] anchors_focal = torch.concat(views[2:], dim=0) - targets_out = self.backbone(targets) + targets_out = self.backbone(images=targets) targets_out = self.projection_head(targets_out) anchors_out = self.encode_masked(anchors) anchors_focal_out = self.encode_masked(anchors_focal) @@ -66,13 +67,13 @@ def training_step(self, batch, batch_idx): def encode_masked(self, anchors): batch_size, _, _, width = anchors.shape - seq_length = (width // self.anchor_backbone.patch_size) ** 2 + seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2 idx_keep, _ = utils.random_token_mask( size=(batch_size, seq_length), mask_ratio=self.mask_ratio, device=self.device, ) - out = self.anchor_backbone(anchors, idx_keep) + out = self.anchor_backbone(images=anchors, idx_keep=idx_keep) return self.anchor_projection_head(out) def configure_optimizers(self): diff --git a/examples/pytorch_lightning_distributed/simmim.py b/examples/pytorch_lightning_distributed/simmim.py index fc92fec8c..3cfb6f028 100644 --- a/examples/pytorch_lightning_distributed/simmim.py +++ b/examples/pytorch_lightning_distributed/simmim.py @@ -4,7 +4,7 @@ from torch import nn from lightly.models import utils -from lightly.models.modules import masked_autoencoder +from lightly.models.modules import MaskedVisionTransformerTorchvision from lightly.transforms.mae_transform import MAETransform # Same transform as MAE @@ -12,27 
+12,23 @@ class SimMIM(pl.LightningModule): def __init__(self): super().__init__() - decoder_dim = vit.hidden_dim vit = torchvision.models.vit_b_32(pretrained=False) self.mask_ratio = 0.75 self.patch_size = vit.patch_size self.sequence_length = vit.seq_length - self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) + decoder_dim = vit.hidden_dim - # same backbone as MAE - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) + self.backbone = MaskedVisionTransformerTorchvision(vit=vit) # the decoder is a simple linear layer - self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size**2 * 3) + self.decoder = nn.Linear(decoder_dim, vit.patch_size**2 * 3) # L1 loss as paper suggestion self.criterion = nn.L1Loss() def forward_encoder(self, images, batch_size, idx_mask): # pass all the tokens to the encoder, both masked and non masked ones - tokens = self.backbone.images_to_tokens(images, prepend_class_token=True) - tokens_masked = utils.mask_at_index(tokens, idx_mask, self.mask_token) - return self.backbone.encoder(tokens_masked) + return self.backbone.encode(images=images, idx_mask=idx_mask) def forward_decoder(self, x_encoded): return self.decoder(x_encoded) diff --git a/lightly/models/modules/__init__.py b/lightly/models/modules/__init__.py index 41237ba0e..d5c267b46 100644 --- a/lightly/models/modules/__init__.py +++ b/lightly/models/modules/__init__.py @@ -37,9 +37,17 @@ MAEDecoder, MAEEncoder, ) + from lightly.models.modules.masked_vision_transformer_torchvision import ( + MaskedVisionTransformerTorchvision, + ) + if _dependency.timm_vit_available(): # Requires timm >= 0.9.9 from lightly.models.modules.heads_timm import AIMPredictionHead + from lightly.models.modules.masked_autoencoder_timm import MAEDecoderTIMM from lightly.models.modules.masked_causal_vision_transformer import ( MaskedCausalVisionTransformer, ) + from lightly.models.modules.masked_vision_transformer_timm import ( + MaskedVisionTransformerTIMM, + ) diff --git a/lightly/models/modules/masked_autoencoder.py b/lightly/models/modules/masked_autoencoder.py index f1695c4ce..f7a5a2902 100644 --- a/lightly/models/modules/masked_autoencoder.py +++ b/lightly/models/modules/masked_autoencoder.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn +from torch.nn import Linear, Module, Parameter # vision_transformer requires torchvision >= 0.12 from torchvision.models import vision_transformer @@ -61,18 +62,33 @@ def __init__( attention_dropout=attention_dropout, norm_layer=norm_layer, ) + self._initialize_weights() @classmethod - def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder) -> MAEEncoder: - """Creates a MAEEncoder from a torchvision ViT encoder.""" + def from_vit_encoder( + cls, vit_encoder: vision_transformer.Encoder, initialize_weights: bool = True + ) -> MAEEncoder: + """Creates a MAEEncoder from a torchvision ViT encoder. + + Args: + vit_encoder: + A torchvision ViT encoder. + initialize_weights: + If True, then all weights are initialized as in MAE paper. Set this to + False if vit_encoder is pretrained. + + Returns: + A MAEEncoder with the same architecture as vit_encoder. 
+ + """ # Create a new instance with dummy values as they will be overwritten # by the copied vit_encoder attributes encoder = cls( - seq_length=1, - num_layers=1, - num_heads=1, - hidden_dim=1, - mlp_dim=1, + seq_length=197, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, dropout=0, attention_dropout=0, ) @@ -80,6 +96,8 @@ def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder) -> MAEEncoder encoder.dropout = vit_encoder.dropout encoder.layers = vit_encoder.layers encoder.ln = vit_encoder.ln + if initialize_weights: + encoder._initialize_weights() return encoder def forward( @@ -133,6 +151,10 @@ def interpolate_pos_encoding(self, input: torch.Tensor): pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) + def _initialize_weights(self) -> None: + _initialize_2d_sine_cosine_positional_embedding(self.pos_embedding) + _initialize_linear_layers(self) + class MAEBackbone(vision_transformer.VisionTransformer): """Backbone for the Masked Autoencoder model [0]. @@ -219,8 +241,22 @@ def __init__( ) @classmethod - def from_vit(cls, vit: vision_transformer.VisionTransformer) -> MAEBackbone: - """Creates a MAEBackbone from a torchvision ViT model.""" + def from_vit( + cls, vit: vision_transformer.VisionTransformer, initialize_weights: bool = True + ) -> MAEBackbone: + """Creates a MAEBackbone from a torchvision ViT model. + + Args: + vit: + A torchvision ViT model. + initialize_weights: + If True, then all weights are initialized as in MAE paper. Set this to + False if vit is pretrained. + + Returns: + A MAEBackbone with the same architecture as vit. + + """ # Create a new instance with dummy values as they will be overwritten # by the copied vit_encoder attributes backbone = cls( @@ -240,7 +276,9 @@ def from_vit(cls, vit: vision_transformer.VisionTransformer) -> MAEBackbone: backbone.class_token = vit.class_token backbone.seq_length = vit.seq_length backbone.heads = vit.heads - backbone.encoder = MAEEncoder.from_vit_encoder(vit.encoder) + backbone.encoder = MAEEncoder.from_vit_encoder( + vit.encoder, initialize_weights=initialize_weights + ) return backbone def forward( @@ -307,6 +345,18 @@ def images_to_tokens( tokens = utils.prepend_class_token(tokens, self.class_token) return tokens + def _initialize_weights(self) -> None: + # Initialize the patch embedding layer like a linear layer instead of conv + # layer. + w = self.conv_proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize the class token. + torch.nn.init.normal_(self.class_token, std=0.02) + + self.encoder._initialize_weights() + _initialize_linear_layers(self) + class MAEDecoder(vision_transformer.Encoder): """Decoder for the Masked Autoencoder model [0]. @@ -366,6 +416,7 @@ def __init__( ) self.decoder_embed = nn.Linear(embed_input_dim, hidden_dim, bias=True) self.prediction_head = nn.Linear(hidden_dim, out_dim) + self._initialize_weights() def forward(self, input: torch.Tensor) -> torch.Tensor: """Returns predicted pixel values from encoded tokens. 
@@ -429,3 +480,32 @@ def predict(self, input: torch.Tensor) -> torch.Tensor: """ return self.prediction_head(input) + + def _initialize_weights(self) -> None: + _initialize_2d_sine_cosine_positional_embedding(self.pos_embedding) + _initialize_linear_layers(self) + + +def _initialize_2d_sine_cosine_positional_embedding(pos_embedding: Parameter) -> None: + _, seq_length, hidden_dim = pos_embedding.shape + grid_size = int((seq_length - 1) ** 0.5) + sine_cosine_embedding = utils.get_2d_sine_cosine_positional_embedding( + embed_dim=hidden_dim, + grid_size=grid_size, + cls_token=True, + ) + pos_embedding.data.copy_( + torch.from_numpy(sine_cosine_embedding).float().unsqueeze(0) + ) + # Freeze positional embedding. + pos_embedding.requires_grad = False + + +def _initialize_linear_layers(module: Module) -> None: + def init(mod: Module) -> None: + if isinstance(mod, Linear): + nn.init.xavier_uniform_(mod.weight) + if mod.bias is not None: + nn.init.constant_(mod.bias, 0) + + module.apply(init) diff --git a/lightly/models/modules/masked_autoencoder_timm.py b/lightly/models/modules/masked_autoencoder_timm.py new file mode 100644 index 000000000..435ed8cf6 --- /dev/null +++ b/lightly/models/modules/masked_autoencoder_timm.py @@ -0,0 +1,182 @@ +from functools import partial +from typing import Callable, Optional + +import torch +import torch.nn as nn +from timm.models.vision_transformer import Block +from torch import Tensor +from torch.nn import LayerNorm, Linear, Module, Parameter, Sequential + +from lightly.models import utils + + +class MAEDecoderTIMM(Module): + """Decoder for the Masked Autoencoder model [0]. + + Decodes encoded patches and predicts pixel values for every patch. + Code inspired by [1]. + + - [0]: Masked Autoencoder, 2021, https://arxiv.org/abs/2111.06377 + - [1]: https://github.com/facebookresearch/mae + + Attributes: + num_patches: + Number of patches. + patch_size: + Patch size. + in_chans: + Number of image input channels. + embed_dim: + Embedding dimension of the encoder. + decoder_embed_dim: + Embedding dimension of the decoder. + decoder_depth: + Depth of transformer. + decoder_num_heads: + Number of attention heads. + mlp_ratio: + Ratio of mlp hidden dim to embedding dim. + proj_drop_rate: + Percentage of elements set to zero after the MLP in the transformer. + attn_drop_rate: + Percentage of elements set to zero after the attention head. + norm_layer: + Normalization layer. + mask_token: + The mask token. 
+ + """ + + def __init__( + self, + num_patches: int, + patch_size: int, + in_chans: int = 3, + embed_dim: int = 1024, + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + mlp_ratio: float = 4.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + norm_layer: Callable[..., nn.Module] = partial(LayerNorm, eps=1e-6), + mask_token: Optional[Parameter] = None, + ): + super().__init__() + + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + self.mask_token = ( + nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + if mask_token is None + else mask_token + ) + + # positional encoding of the decoder + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False + ) # fixed sin-cos embedding + + self.decoder_blocks = Sequential( + *[ + Block( + decoder_embed_dim, + decoder_num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + ) + for i in range(decoder_depth) + ] + ) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear( + decoder_embed_dim, patch_size**2 * in_chans, bias=True + ) # decoder to patch + + self._initialize_weights() + + def forward(self, input: Tensor) -> Tensor: + """Returns predicted pixel values from encoded tokens. + + Args: + input: + Tensor with shape (batch_size, seq_length, embed_input_dim). + + Returns: + Tensor with shape (batch_size, seq_length, out_dim). + + """ + out = self.embed(input) + out = self.decode(out) + return self.predict(out) + + def embed(self, input: Tensor) -> Tensor: + """Embeds encoded input tokens into decoder token dimension. + + This is a single linear layer that changes the token dimension from + embed_input_dim to hidden_dim. + + Args: + input: + Tensor with shape (batch_size, seq_length, embed_input_dim) + containing the encoded tokens. + + Returns: + Tensor with shape (batch_size, seq_length, hidden_dim) containing + the embedded tokens. + + """ + out: Tensor = self.decoder_embed(input) + return out + + def decode(self, input: Tensor) -> Tensor: + """Forward pass through the decoder transformer. + + Args: + input: + Tensor with shape (batch_size, seq_length, hidden_dim) containing + the encoded tokens. + + Returns: + Tensor with shape (batch_size, seq_length, hidden_dim) containing + the decoded tokens. + + """ + output: Tensor = input + self.decoder_pos_embed + output = self.decoder_blocks(output) + output = self.decoder_norm(output) + return output + + def predict(self, input: Tensor) -> Tensor: + """Predics pixel values from decoded tokens. + + Args: + input: + Tensor with shape (batch_size, seq_length, hidden_dim) containing + the decoded tokens. + + Returns: + Tensor with shape (batch_size, seq_length, out_dim) containing + predictions for each token. 
+ + """ + out: Tensor = self.decoder_pred(input) + return out + + def _initialize_weights(self) -> None: + torch.nn.init.normal_(self.mask_token, std=0.02) + utils.initialize_2d_sine_cosine_positional_embedding(self.decoder_pos_embed) + self.apply(_init_weights) + + +def _init_weights(module: Module) -> None: + if isinstance(module, Linear): + nn.init.xavier_uniform_(module.weight) + if isinstance(module, Linear) and module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, LayerNorm): + nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) diff --git a/lightly/models/modules/masked_vision_transformer.py b/lightly/models/modules/masked_vision_transformer.py new file mode 100644 index 000000000..566af7b9e --- /dev/null +++ b/lightly/models/modules/masked_vision_transformer.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from torch import Tensor + + +class MaskedVisionTransformer(ABC): + """ + Abstract base class for Masked Vision Transformer models. + + Defines the interface for a Masked Vision Transformer. This class includes abstract + methods that must be implemented by concrete subclasses to define the forward pass, + tokenization of images, and various operations needed for the transformer. + """ + + @abstractmethod + def forward( + self, + images: Tensor, + idx_mask: Optional[Tensor] = None, + idx_keep: Optional[Tensor] = None, + ) -> Tensor: + pass + + @abstractmethod + def images_to_tokens(self, images: Tensor) -> Tensor: + pass + + @abstractmethod + def add_prefix_tokens(self, x: Tensor) -> Tensor: + pass + + @abstractmethod + def add_pos_embed(self, x: Tensor) -> Tensor: + pass diff --git a/lightly/models/modules/masked_vision_transformer_timm.py b/lightly/models/modules/masked_vision_transformer_timm.py new file mode 100644 index 000000000..446f339c3 --- /dev/null +++ b/lightly/models/modules/masked_vision_transformer_timm.py @@ -0,0 +1,235 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +from timm.layers.pos_embed import resample_abs_pos_embed +from timm.models.vision_transformer import VisionTransformer +from torch import Tensor +from torch.nn import LayerNorm, Linear, Module, Parameter + +from lightly.models import utils +from lightly.models.modules.masked_vision_transformer import MaskedVisionTransformer + + +class MaskedVisionTransformerTIMM(MaskedVisionTransformer, Module): + """Masked Vision Transformer class using TIMM. + + Attributes: + vit: + The VisionTransformer object of TIMM. + mask_token: + The mask token. + + """ + + def __init__( + self, + vit: VisionTransformer, + mask_token: Optional[Parameter] = None, + ) -> None: + super().__init__() + self.vit = vit + self.mask_token = ( + mask_token + if mask_token is not None + else Parameter(torch.zeros(1, 1, self.vit.embed_dim)) + ) + self._initialize_weights() + + @property + def sequence_length(self) -> int: + seq_len: int = self.vit.patch_embed.num_patches + self.vit.num_prefix_tokens + return seq_len + + def forward( + self, + images: Tensor, + idx_mask: Optional[Tensor] = None, + idx_keep: Optional[Tensor] = None, + ) -> Tensor: + """Returns encoded class tokens from a batch of images. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + idx_mask: + Tensor with shape (batch_size, num_tokens_to_mask) where each + entry is an index of the token to mask in the respective batch. + If specified, the indexed tokens are masked with self.mask_token. 
+ idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be passed to the + encoder. + + Returns: + Tensor with shape (batch_size, vit.embed_dim) containing the + encoded class token for every image. + + """ + x = self.encode(images, idx_mask=idx_mask, idx_keep=idx_keep) + if self.vit.attn_pool is not None: + x = self.vit.attn_pool(x) + elif self.vit.global_pool == "avg": + x = x[:, self.vit.num_prefix_tokens :].mean(dim=1) + elif self.vit.global_pool: + x = x[:, 0] # class token + return x + + def encode( + self, + images: Tensor, + idx_mask: Optional[Tensor] = None, + idx_keep: Optional[Tensor] = None, + ) -> Tensor: + """Encode input images. + + Args: + input: + Batch of input images. + idx_mask: + Tensor with shape (batch_size, num_tokens_to_mask) where each + entry is an index of the token to mask in the respective batch. + If specified, the indexed tokens are masked with self.mask_token. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be encoded. + + Returns: + Batch of encoded output tokens. + """ + # convert images to tokens + input = self.images_to_tokens(images) + # add prefix tokens if needed + input = self.add_prefix_tokens(input) + + if idx_mask is not None: + input = utils.mask_at_index(input, idx_mask, self.mask_token) + # add positional encoding + input = self.add_pos_embed(input) + + if idx_keep is not None: + input = utils.get_at_index(input, idx_keep) + # normalization layer + input = self.vit.norm_pre(input) + # apply Transformer blocks + input = self.vit.blocks(input) + # normalize + out: Tensor = self.vit.norm(input) + return out + + def images_to_tokens(self, images: Tensor) -> Tensor: + """Converts images into patch tokens. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + + Returns: + Tensor with shape (batch_size, vit.patch_embed.num_patches, vit.embed_dim) + containing the patch tokens (excluding prefix tokens). + """ + tokens: Tensor = self.vit.patch_embed(images) + if self.vit.dynamic_img_size: + tokens = tokens.permute(0, 3, 1, 2) # NHWC -> NCHW + tokens = tokens.flatten(2).transpose(1, 2) # NCHW -> NLC + return tokens + + def add_prefix_tokens(self, x: Tensor) -> Tensor: + """Adds prefix tokens to image patch tokens. + + Args: + x: + Tensor with shape (batch_size, vit.patch_embed.num_patches, vit.embed_dim) + containing the image patch tokens + + Returns: + Tensor with shape (batch_size, self.sequence_length, vit.embed_dim) containing + the image patch tokens and prefix tokens. + """ + prefix_tokens = [] + if self.vit.cls_token is not None: + prefix_tokens.append(self.vit.cls_token.expand(x.shape[0], -1, -1)) + if self.vit.reg_token is not None: + prefix_tokens.append(self.vit.reg_token.expand(x.shape[0], -1, -1)) + if prefix_tokens: + x = torch.cat(prefix_tokens + [x], dim=1) + return x + + def add_pos_embed(self, x: Tensor) -> Tensor: + """Adds positional embeddings to the input tensor based on the Vision Transformer + (ViT) architecture in vit. + + Args: + x: + Input tensor with shape (batch_size, self.sequence_length, vit.embed_dim). + + Returns: + Tensor after adding positional embeddings, with the same shape as the input. 
+ """ + + x_prefix = x[:, : self.vit.num_prefix_tokens, :] + x = x[:, self.vit.num_prefix_tokens :, :] + if self.vit.dynamic_img_size: + x = x.transpose(1, 2) # NLC -> NCL + total_size = torch.numel(x) + batch_size = x.size(0) + num_channels = x.size(1) + grid_size = int(math.sqrt(total_size / (batch_size * num_channels))) + x = x.view( + x.size(0), + x.size(1), + grid_size, + grid_size, + ) # NCL -> NCHW + + # NCHW -> NHWC + x = x.permute(0, 2, 3, 1) + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.vit.pos_embed, + (H, W), + num_prefix_tokens=0 + if self.vit.no_embed_class + else self.vit.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.vit.pos_embed + + if self.vit.no_embed_class: + x = x + pos_embed + if self.vit.num_prefix_tokens: + x = torch.cat((x_prefix, x), dim=1) + else: + if self.vit.num_prefix_tokens: + x = torch.cat((x_prefix, x), dim=1) + x = x + pos_embed + out: Tensor = self.vit.pos_drop(x) + return out + + def _initialize_weights(self) -> None: + # Initialize the patch embedding layer like a linear layer instead of conv + # layer. + w = self.vit.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize the class token. + torch.nn.init.normal_(self.vit.cls_token, std=0.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(_init_weights) + + utils.initialize_2d_sine_cosine_positional_embedding(self.vit.pos_embed) + + +def _init_weights(module: Module) -> None: + if isinstance(module, Linear): + nn.init.xavier_uniform_(module.weight) + if isinstance(module, Linear) and module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, LayerNorm): + nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) diff --git a/lightly/models/modules/masked_vision_transformer_torchvision.py b/lightly/models/modules/masked_vision_transformer_torchvision.py new file mode 100644 index 000000000..dcb7f1a05 --- /dev/null +++ b/lightly/models/modules/masked_vision_transformer_torchvision.py @@ -0,0 +1,219 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import Linear, Module, Parameter +from torchvision.models.vision_transformer import VisionTransformer + +from lightly.models import utils +from lightly.models.modules.masked_vision_transformer import MaskedVisionTransformer + + +class MaskedVisionTransformerTorchvision(MaskedVisionTransformer, Module): + """Masked Vision Transformer class using Torchvision. + + Attributes: + vit: + The VisionTransformer object of Torchvision. + mask_token: + The mask token. + + """ + + def __init__( + self, + vit: VisionTransformer, + mask_token: Optional[Parameter] = None, + ) -> None: + super().__init__() + self.vit = vit + self.mask_token = ( + mask_token + if mask_token is not None + else Parameter(torch.zeros(1, 1, self.vit.hidden_dim)) + ) + self._initialize_weights() + + @property + def sequence_length(self) -> int: + seq_len: int = self.vit.seq_length + return seq_len + + def interpolate_pos_encoding(self, input: Tensor) -> Tensor: + """Returns the interpolated positional embedding for the given input. + + This function interpolates self.pos_embedding for all tokens in the input, + ignoring the class token. This allows encoding variable sized images. + + Args: + input: + Input tensor with shape (batch_size, num_sequences). 
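+
+ Returns:
+ Positional embedding with one entry per input token (including the
+ class token), interpolated if the number of patches differs from the
+ original grid.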
+ + """ + # code copied from: + # https://github.com/facebookresearch/msn/blob/4388dc1eadbe3042b85d3296d41b9b207656e043/src/deit.py#L291 + npatch = input.shape[1] - 1 + N = self.vit.encoder.pos_embedding.shape[1] - 1 + if npatch == N: + pos_embedding: Tensor = self.vit.encoder.pos_embedding + return pos_embedding + class_emb = self.vit.encoder.pos_embedding[:, 0] + pos_embedding = self.vit.encoder.pos_embedding[:, 1:] + dim = input.shape[-1] + pos_embedding = nn.functional.interpolate( + pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(npatch / N), + mode="bicubic", + ) + pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) + + def forward( + self, + images: Tensor, + idx_mask: Optional[Tensor] = None, + idx_keep: Optional[Tensor] = None, + ) -> Tensor: + """Returns encoded class tokens from a batch of images. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + idx_mask: + Tensor with shape (batch_size, num_tokens_to_mask) where each + entry is an index of the token to mask in the respective batch. + If specified, the indexed tokens are masked with self.mask_token. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be passed to the + encoder. + + Returns: + Tensor with shape (batch_size, vit.hidden_dim) containing the + encoded class token for every image. + + """ + out = self.encode(images, idx_mask=idx_mask, idx_keep=idx_keep) + class_token = out[:, 0] + return class_token + + def encode( + self, + images: Tensor, + idx_mask: Optional[Tensor] = None, + idx_keep: Optional[Tensor] = None, + ) -> Tensor: + """Encode input images. + + Args: + input: + Batch of input images. + idx_mask: + Tensor with shape (batch_size, num_tokens_to_mask) where each + entry is an index of the token to mask in the respective batch. + If specified, the indexed tokens are masked with self.mask_token. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be encoded. + + Returns: + Batch of encoded output tokens. + """ + # convert images to tokens + input = self.images_to_tokens(images) + # add prefix tokens if needed + input = self.add_prefix_tokens(input) + + if idx_mask is not None: + input = utils.mask_at_index(input, idx_mask, self.mask_token) + # add positional encoding + input = self.add_pos_embed(input) + + if idx_keep is not None: + input = utils.get_at_index(input, idx_keep) + out: Tensor = self.vit.encoder.ln( + self.vit.encoder.layers(self.vit.encoder.dropout(input)) + ) + return out + + def images_to_tokens(self, images: Tensor) -> Tensor: + """Converts images into patch tokens. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + + Returns: + Tensor with shape (batch_size, vit.seq_length-1, vit.hidden_dim) containing + the image patch tokens. + """ + x = self.vit.conv_proj(images) + tokens: Tensor = x.flatten(2).transpose(1, 2) + return tokens + + def add_prefix_tokens(self, x: Tensor, prepend_class_token: bool = True) -> Tensor: + """Adds class token to image patch tokens. 
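+
+ The class token stored in vit.class_token is prepended to the sequence
+ when prepend_class_token is True.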
+ + Args: + x: + Tensor with shape (batch_size, vit.seq_length-1, vit.hidden_dim) + containing the image patch tokens + prepend_class_token: + Boolean flag that determines if a class token should be prepended. + + Returns: + Tensor with shape (batch_size, vit.seq_length, vit.hidden_dim) containing + the image patch tokens and class tokens. + """ + if prepend_class_token: + x = utils.prepend_class_token(x, self.vit.class_token) + return x + + def add_pos_embed(self, x: Tensor) -> Tensor: + """Adds positional embeddings to the input tensor based on the Vision Transformer + (ViT) architecture in vit. + + Args: + x: + Input tensor with shape (batch_size, self.sequence_length, vit.hidden_dim). + + Returns: + Tensor after adding positional embeddings, with the same shape as the input. + """ + # TODO(Ersi:1/24) This adds positional encoding to the prefix tokens as well. + # Give the option of not doing so, as is the case for TIMM. + x = x + self.interpolate_pos_encoding(x) + return x + + def _initialize_weights(self) -> None: + # Initialize the patch embedding layer like a linear layer instead of conv + # layer. + w = self.vit.conv_proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize the class token. + torch.nn.init.normal_(self.vit.class_token, std=0.02) + + # Initialize positional encoding. + utils.initialize_2d_sine_cosine_positional_embedding( + self.vit.encoder.pos_embedding + ) + + # Initialize linear layers. + _initialize_linear_layers(self) + + +def _initialize_linear_layers(module: Module) -> None: + def init(mod: Module) -> None: + if isinstance(mod, Linear): + nn.init.xavier_uniform_(mod.weight) + if mod.bias is not None: + nn.init.constant_(mod.bias, 0) + + module.apply(init) diff --git a/lightly/models/utils.py b/lightly/models/utils.py index d5a22ea96..ce5febc10 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -638,6 +638,21 @@ def add_stochastic_depth_to_blocks(vit: Module, prob: float = 0.0, mode="row") - mod.mlp = Sequential(mod.mlp, StochasticDepth(p=prob, mode=mode)) +def initialize_2d_sine_cosine_positional_embedding(pos_embedding: Parameter) -> None: + _, seq_length, hidden_dim = pos_embedding.shape + grid_size = int((seq_length - 1) ** 0.5) + sine_cosine_embedding = get_2d_sine_cosine_positional_embedding( + embed_dim=hidden_dim, + grid_size=grid_size, + cls_token=True, + ) + pos_embedding.data.copy_( + torch.from_numpy(sine_cosine_embedding).float().unsqueeze(0) + ) + # Freeze positional embedding. 
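+ # The fixed sine-cosine embedding is not updated during training.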
+ pos_embedding.requires_grad = False + + def get_2d_sine_cosine_positional_embedding( embed_dim: int, grid_size: int, cls_token: bool ) -> NDArray[np.float32]: diff --git a/tests/models/modules/test_masked_autoencoder.py b/tests/models/modules/test_masked_autoencoder.py index 42210bfdb..0e97041a5 100644 --- a/tests/models/modules/test_masked_autoencoder.py +++ b/tests/models/modules/test_masked_autoencoder.py @@ -87,6 +87,16 @@ def test_forward(self): def test_forward_cuda(self): self._test_forward(torch.device("cuda")) + def test_images_to_tokens(self) -> None: + torch.manual_seed(0) + vit = self._vit() + backbone = MAEBackbone.from_vit(vit) + images = torch.rand(2, 3, 224, 224) + assert torch.all( + vit._process_input(images) + == backbone.images_to_tokens(images, prepend_class_token=False) + ) + @unittest.skipUnless( dependency.torchvision_vit_available(), "Torchvision ViT not available" diff --git a/tests/models/modules/test_masked_autoencoder_timm.py b/tests/models/modules/test_masked_autoencoder_timm.py new file mode 100644 index 000000000..0139db5d8 --- /dev/null +++ b/tests/models/modules/test_masked_autoencoder_timm.py @@ -0,0 +1,108 @@ +import unittest + +import torch + +from lightly.models import utils +from lightly.utils import dependency + +if dependency.timm_vit_available(): + from timm.models.vision_transformer import VisionTransformer, vit_base_patch32_224 + + from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM + + class TestMaskedVisionTransformerTIMM(unittest.TestCase): + def _vit(self) -> VisionTransformer: + return vit_base_patch32_224() + + def test_from_vit(self) -> None: + MaskedVisionTransformerTIMM(vit=self._vit()) + + def _test_forward( + self, device: torch.device, batch_size: int = 8, seed: int = 0 + ) -> None: + torch.manual_seed(seed) + vit = self._vit() + backbone = MaskedVisionTransformerTIMM(vit=vit).to(device) + images = torch.rand( + batch_size, 3, vit.patch_embed.img_size[0], vit.patch_embed.img_size[0] + ).to(device) + _idx_keep, _ = utils.random_token_mask( + size=(batch_size, backbone.sequence_length), + device=device, + ) + for idx_keep in [None, _idx_keep]: + with self.subTest(idx_keep=idx_keep): + class_tokens = backbone(images=images, idx_keep=idx_keep) + + # output shape must be correct + expected_shape = [batch_size, vit.embed_dim] + self.assertListEqual(list(class_tokens.shape), expected_shape) + + # output must have reasonable numbers + self.assertTrue(torch.all(torch.not_equal(class_tokens, torch.inf))) + + def test_forward(self) -> None: + self._test_forward(torch.device("cpu")) + + @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.") + def test_forward_cuda(self) -> None: + self._test_forward(torch.device("cuda")) + + def test_images_to_tokens(self) -> None: + torch.manual_seed(0) + vit = self._vit() + backbone = MaskedVisionTransformerTIMM(vit=vit) + images = torch.rand(2, 3, 224, 224) + assert torch.all( + vit.patch_embed(images) == backbone.images_to_tokens(images=images) + ) + + class TestMAEDecoderTIMM(unittest.TestCase): + def test_init(self) -> None: + MAEDecoderTIMM( + num_patches=49, + patch_size=32, + embed_dim=128, + decoder_embed_dim=256, + decoder_depth=2, + decoder_num_heads=4, + mlp_ratio=4.0, + proj_drop_rate=0.0, + attn_drop_rate=0.0, + ) + + def _test_forward( + self, device: torch.device, batch_size: int = 8, seed: int = 0 + ) -> None: + torch.manual_seed(seed) + seq_length = 50 + embed_input_dim = 128 + patch_size = 32 + out_dim = 3 * patch_size**2 + decoder = MAEDecoderTIMM( 
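+ # Same decoder configuration as in test_init above.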
+ num_patches=49, + patch_size=32, + embed_dim=embed_input_dim, + decoder_embed_dim=256, + decoder_depth=2, + decoder_num_heads=4, + mlp_ratio=4.0, + proj_drop_rate=0.0, + attn_drop_rate=0.0, + ).to(device) + tokens = torch.rand(batch_size, seq_length, embed_input_dim).to(device) + predictions = decoder(tokens) + + # output shape must be correct + expected_shape = [batch_size, seq_length, out_dim] + self.assertListEqual(list(predictions.shape), expected_shape) + + # output must have reasonable numbers + self.assertTrue(torch.all(torch.not_equal(predictions, torch.inf))) + + def test_forward(self) -> None: + self._test_forward(torch.device("cpu")) + + @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.") + def test_forward_cuda(self) -> None: + self._test_forward(torch.device("cuda"))
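
For reference, a minimal usage sketch of how the modules introduced above compose. It only uses calls exercised by the tests in this diff (MaskedVisionTransformerTIMM, MAEDecoderTIMM, utils.random_token_mask, and the lightly.models.modules import path); the decoder hyperparameters are the small test values, not a training recipe, and the masking ratio is an illustrative choice.

import torch
from timm.models.vision_transformer import vit_base_patch32_224

from lightly.models import utils
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM

# Wrap a TIMM ViT so that tokens can be masked or dropped before encoding.
vit = vit_base_patch32_224()
backbone = MaskedVisionTransformerTIMM(vit=vit)

images = torch.rand(2, 3, 224, 224)
idx_keep, _ = utils.random_token_mask(
    size=(2, backbone.sequence_length),
    mask_ratio=0.75,
)

# Encode only the kept tokens; the forward pass returns the class token.
class_tokens = backbone(images=images, idx_keep=idx_keep)  # (2, vit.embed_dim)

# The decoder maps token embeddings to per-patch pixel predictions.
decoder = MAEDecoderTIMM(
    num_patches=vit.patch_embed.num_patches,
    patch_size=vit.patch_embed.patch_size[0],
    embed_dim=vit.embed_dim,
    decoder_embed_dim=256,
    decoder_depth=2,
    decoder_num_heads=4,
    mlp_ratio=4.0,
    proj_drop_rate=0.0,
    attn_drop_rate=0.0,
)
tokens = torch.rand(2, backbone.sequence_length, vit.embed_dim)
predictions = decoder(tokens)  # (2, sequence_length, 3 * patch_size ** 2)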