Ersi lig 3912 refactor mae to use timm vit #1461

Merged Mar 6, 2024 · 59 commits (changes shown from 17 commits)

Commits
951d43e
Add MAE evaluation
guarin May 30, 2023
503cc44
Add stochastic depth dropout
guarin May 31, 2023
ac43499
Add MAE
guarin May 31, 2023
15bfe3a
Drop assertion
guarin May 31, 2023
49c85c0
Fix smooth cross entropy loss and mixup
guarin May 31, 2023
9d95783
Update comments
guarin May 31, 2023
0bb601d
Add layer lr decay and weight decay
guarin Jun 5, 2023
d7d69af
Update comment
guarin Jun 5, 2023
ec05437
Add test for MAE images_to_tokens
guarin Jun 5, 2023
923a606
Disable BN update
guarin Jun 5, 2023
bdce8a6
Add BN before classification head
guarin Jun 6, 2023
316f918
Format
guarin Jun 6, 2023
a6943fd
Fix BN freezing
guarin Jun 6, 2023
1a2b454
Cleanup
guarin Jun 6, 2023
bc066ae
Use torch.no_grad instead of deactivating gradients manually
guarin Jun 6, 2023
d56e340
Create new stochastic depth instances
guarin Jun 6, 2023
5ed6803
Add mask token to learnable params
guarin Jun 6, 2023
4f0baf1
Add sine-cosine positional embedding
guarin Jun 6, 2023
9c4a8cf
Initialize parameters as in paper
guarin Jun 6, 2023
9904c10
Merge branch 'master' into guarin-lig-3056-add-mae-imagenet-benchmark
guarin Dec 6, 2023
83edd1c
Fix types
guarin Dec 6, 2023
e27946e
Format
guarin Dec 6, 2023
0672b0a
Merge branch 'guarin-lig-3056-add-mae-imagenet-benchmark' of github.c…
ersi-lightly Dec 17, 2023
45433c5
adjusted to existing interface
ersi-lightly Dec 18, 2023
c5cab9e
draft
ersi-lightly Dec 19, 2023
017168e
remove
ersi-lightly Dec 19, 2023
278423b
added modifications
ersi-lightly Jan 4, 2024
fde116c
added mae implementation with timm and example
ersi-lightly Jan 5, 2024
f008645
formatted
ersi-lightly Jan 5, 2024
c97112e
fixed import
ersi-lightly Jan 5, 2024
2e55d6b
removed
ersi-lightly Jan 5, 2024
484add1
fixed typing
ersi-lightly Jan 5, 2024
971c19a
addressed comments
ersi-lightly Jan 9, 2024
1ec7470
fixed typing and formatted
ersi-lightly Jan 9, 2024
76ee356
addressed comments
ersi-lightly Jan 9, 2024
edb2d42
added docstring and formatted
ersi-lightly Jan 9, 2024
f00d320
removed images to tokens method
ersi-lightly Jan 10, 2024
cc263fe
Ersi lig 3910 update mae benchmark code (#1468)
ersi-lightly Feb 23, 2024
f7a532b
resolved conflict
ersi-lightly Feb 23, 2024
cc9d4ac
resolved conflicts
ersi-lightly Feb 23, 2024
eda5ac0
formatted
ersi-lightly Feb 23, 2024
0993ec5
adjusted examples
ersi-lightly Feb 25, 2024
4842901
removed comment
ersi-lightly Feb 25, 2024
bc9c6c3
added test
ersi-lightly Feb 26, 2024
73cb2ec
added message in case of ImportError
ersi-lightly Feb 26, 2024
12eeb14
fixed skipping of test
ersi-lightly Feb 26, 2024
ab15124
removed example
ersi-lightly Feb 26, 2024
c22b2ff
handling the TIMM dependency
ersi-lightly Feb 26, 2024
3fb9e86
added note to docs for MAE installation
ersi-lightly Feb 26, 2024
17760fd
added unit tests for MAE with torchvision
ersi-lightly Feb 26, 2024
44cf4e8
removed unnecessary mask token definition
ersi-lightly Feb 26, 2024
07fdae4
addressed comments
ersi-lightly Feb 29, 2024
a0f87ac
moved test to separate file
ersi-lightly Feb 29, 2024
f61b708
added typing
ersi-lightly Feb 29, 2024
0f8a927
fixed import
ersi-lightly Feb 29, 2024
a18bccd
fixes typing
ersi-lightly Feb 29, 2024
b16bba8
fixed typing
ersi-lightly Feb 29, 2024
90a2ee0
fixed typing
ersi-lightly Feb 29, 2024
d89808f
Ersi lig 4471 cleanup and merge mae branch (#1510)
ersi-lightly Mar 5, 2024
9 changes: 8 additions & 1 deletion benchmarks/imagenet/vitb16/mae.py
@@ -1,12 +1,19 @@
+import sys
 from typing import List, Tuple
 
 import torch
 from pytorch_lightning import LightningModule
-from timm.models.vision_transformer import vit_base_patch16_224
 from torch import Tensor
 from torch.nn import MSELoss, Parameter
 from torch.optim import AdamW
 
+from lightly.utils import dependency
+
+if dependency.timm_vit_available():
+    from timm.models.vision_transformer import vit_base_patch16_224
+else:
+    sys.exit(1)
+
 from lightly.models import utils
 from lightly.models.modules import (
     masked_autoencoder_timm,
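The conditional import above makes the optional TIMM dependency explicit: the benchmark exits early when timm is missing instead of failing later with an ImportError. As a rough illustration, a helper like `timm_vit_available` can be built on an import probe; this is a hypothetical sketch, not the actual body of `lightly.utils.dependency`:

    import importlib.util


    def timm_vit_available() -> bool:
        # Hypothetical sketch: report whether the optional timm package can be
        # imported. The real lightly helper may additionally check the version.
        return importlib.util.find_spec("timm") is not None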
7 changes: 7 additions & 0 deletions docs/source/getting_started/install.rst
@@ -32,6 +32,13 @@ If you want to work with video files you need to additionally install
 
     pip install av
 
+If you want to use the Masked Autoencoder you need to additionally install
+`TIMM <https://github.com/huggingface/pytorch-image-models>`_.
+
+.. code-block:: bash
+
+    pip install timm
+
 Next Steps
 ------------
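Once timm is installed, the ViT backbones used throughout this PR become importable. A minimal sanity check, assuming a default timm installation (for a 224×224 input, the 197 output tokens are 196 patches plus the class token):

    import torch
    from timm.models.vision_transformer import vit_base_patch16_224

    vit = vit_base_patch16_224()
    images = torch.randn(1, 3, 224, 224)
    features = vit.forward_features(images)
    print(features.shape)  # torch.Size([1, 197, 768]) with default settings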
44 changes: 23 additions & 21 deletions examples/pytorch/mae.py
@@ -1,13 +1,13 @@
 # Note: The model and training settings do not follow the reference settings
 # from the paper. The settings are chosen such that the example can easily be
 # run on a small dataset with a single GPU.
 
 import torch
-import torchvision
+from timm.models.vision_transformer import vit_base_patch32_224
 from torch import nn
 
 from lightly.models import utils
-from lightly.models.modules import masked_autoencoder
+from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
 from lightly.transforms.mae_transform import MAETransform
 
@@ -17,31 +17,31 @@ def __init__(self, vit):
 
         decoder_dim = 512
         self.mask_ratio = 0.75
-        self.patch_size = vit.patch_size
-        self.sequence_length = vit.seq_length
-        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
-        self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit)
-        self.decoder = masked_autoencoder.MAEDecoder(
-            seq_length=vit.seq_length,
-            num_layers=1,
-            num_heads=16,
-            embed_input_dim=vit.hidden_dim,
-            hidden_dim=decoder_dim,
-            mlp_dim=decoder_dim * 4,
-            out_dim=vit.patch_size**2 * 3,
-            dropout=0,
-            attention_dropout=0,
+        self.patch_size = vit.patch_embed.patch_size[0]
+
+        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
+        self.sequence_length = self.backbone.sequence_length
+        self.decoder = MAEDecoderTIMM(
+            num_patches=vit.patch_embed.num_patches,
+            patch_size=self.patch_size,
+            embed_dim=vit.embed_dim,
+            decoder_embed_dim=decoder_dim,
+            decoder_depth=1,
+            decoder_num_heads=16,
+            mlp_ratio=4.0,
+            proj_drop_rate=0.0,
+            attn_drop_rate=0.0,
         )
 
     def forward_encoder(self, images, idx_keep=None):
-        return self.backbone.encode(images, idx_keep)
+        return self.backbone.encode(images=images, idx_keep=idx_keep)
 
     def forward_decoder(self, x_encoded, idx_keep, idx_mask):
         # build decoder input
         batch_size = x_encoded.shape[0]
         x_decode = self.decoder.embed(x_encoded)
         x_masked = utils.repeat_token(
-            self.mask_token, (batch_size, self.sequence_length)
+            self.decoder.mask_token, (batch_size, self.sequence_length)
         )
         x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))
@@ -60,8 +60,10 @@ def forward(self, images):
             mask_ratio=self.mask_ratio,
             device=images.device,
         )
-        x_encoded = self.forward_encoder(images, idx_keep)
-        x_pred = self.forward_decoder(x_encoded, idx_keep, idx_mask)
+        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
+        x_pred = self.forward_decoder(
+            x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask
+        )
 
         # get image patches for masked tokens
         patches = utils.patchify(images, self.patch_size)
@@ -70,7 +72,7 @@ def forward(self, images):
         return x_pred, target
 
 
-vit = torchvision.models.vit_b_32(pretrained=False)
+vit = vit_base_patch32_224()
 model = MAE(vit)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
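As a reference for the refactored example, a minimal smoke test might look as follows; it assumes the MAE class defined above and a lightly installation with timm available:

    import torch
    from timm.models.vision_transformer import vit_base_patch32_224

    vit = vit_base_patch32_224()
    model = MAE(vit)  # the MAE class from the example above

    images = torch.randn(2, 3, 224, 224)
    x_pred, target = model(images)
    # Both tensors cover only the masked patches, i.e. roughly
    # (batch_size, num_masked, patch_size**2 * 3).
    print(x_pred.shape, target.shape)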
121 changes: 0 additions & 121 deletions examples/pytorch/mae_timm.py

This file was deleted.

10 changes: 5 additions & 5 deletions examples/pytorch/msn.py
@@ -9,8 +9,8 @@
 
 from lightly.loss import MSNLoss
 from lightly.models import utils
+from lightly.models.modules import MaskedVisionTransformerTorchvision
 from lightly.models.modules.heads import MSNProjectionHead
-from lightly.models.modules.masked_autoencoder import MAEBackbone
 from lightly.transforms.msn_transform import MSNTransform
 
@@ -19,7 +19,7 @@ def __init__(self, vit):
         super().__init__()
 
         self.mask_ratio = 0.15
-        self.backbone = MAEBackbone.from_vit(vit)
+        self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
         self.projection_head = MSNProjectionHead(input_dim=384)
 
         self.anchor_backbone = copy.deepcopy(self.backbone)
@@ -31,18 +31,18 @@ def __init__(self, vit):
         self.prototypes = nn.Linear(256, 1024, bias=False).weight
 
     def forward(self, images):
-        out = self.backbone(images)
+        out = self.backbone(images=images)
         return self.projection_head(out)
 
     def forward_masked(self, images):
         batch_size, _, _, width = images.shape
-        seq_length = (width // self.anchor_backbone.patch_size) ** 2
+        seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2
         idx_keep, _ = utils.random_token_mask(
             size=(batch_size, seq_length),
             mask_ratio=self.mask_ratio,
             device=images.device,
         )
-        out = self.anchor_backbone(images, idx_keep)
+        out = self.anchor_backbone(images=images, idx_keep=idx_keep)
         return self.anchor_projection_head(out)
 
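The change from `self.anchor_backbone.patch_size` to `self.anchor_backbone.vit.patch_size` reflects that the backbone is now a wrapper that keeps the raw torchvision ViT reachable as `.vit`. A hedged sketch of the sequence-length computation, using a ViT-B/32 for illustration (the example itself uses a smaller ViT):

    import torchvision
    from lightly.models.modules import MaskedVisionTransformerTorchvision

    vit = torchvision.models.vit_b_32()
    backbone = MaskedVisionTransformerTorchvision(vit=vit)
    # Attributes of the wrapped model are read through backbone.vit.
    seq_length = (224 // backbone.vit.patch_size) ** 2  # (224 // 32) ** 2 = 49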
12 changes: 6 additions & 6 deletions examples/pytorch/pmsn.py
@@ -9,8 +9,8 @@
 
 from lightly.loss import PMSNLoss
 from lightly.models import utils
+from lightly.models.modules import MaskedVisionTransformerTorchvision
 from lightly.models.modules.heads import MSNProjectionHead
-from lightly.models.modules.masked_autoencoder import MAEBackbone
 from lightly.transforms import MSNTransform
 
@@ -19,7 +19,7 @@ def __init__(self, vit):
         super().__init__()
 
         self.mask_ratio = 0.15
-        self.backbone = MAEBackbone.from_vit(vit)
+        self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
         self.projection_head = MSNProjectionHead(384)
 
         self.anchor_backbone = copy.deepcopy(self.backbone)
@@ -31,18 +31,18 @@ def __init__(self, vit):
         self.prototypes = nn.Linear(256, 1024, bias=False).weight
 
     def forward(self, images):
-        out = self.backbone(images)
+        out = self.backbone(images=images)
         return self.projection_head(out)
 
     def forward_masked(self, images):
         batch_size, _, _, width = images.shape
-        seq_length = (width // self.anchor_backbone.patch_size) ** 2
+        seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2
         idx_keep, _ = utils.random_token_mask(
             size=(batch_size, seq_length),
             mask_ratio=self.mask_ratio,
             device=images.device,
         )
-        out = self.anchor_backbone(images, idx_keep)
+        out = self.anchor_backbone(images=images, idx_keep=idx_keep)
         return self.anchor_projection_head(out)
@@ -106,7 +106,7 @@ def forward_masked(self, images):
     anchors = views[1]
     anchors_focal = torch.concat(views[2:], dim=0)
 
-    targets_out = model.backbone(targets)
+    targets_out = model.backbone(images=targets)
     targets_out = model.projection_head(targets_out)
     anchors_out = model.forward_masked(anchors)
    anchors_focal_out = model.forward_masked(anchors_focal)
14 changes: 6 additions & 8 deletions examples/pytorch/simmim.py
@@ -3,7 +3,9 @@
 from torch import nn
 
 from lightly.models import utils
-from lightly.models.modules import masked_autoencoder
+from lightly.models.modules.masked_vision_transformer_torchvision import (
+    MaskedVisionTransformerTorchvision,
+)
 from lightly.transforms.mae_transform import MAETransform  # Same transform as MAE
 
@@ -15,19 +17,15 @@ def __init__(self, vit):
         self.mask_ratio = 0.75
         self.patch_size = vit.patch_size
         self.sequence_length = vit.seq_length
-        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
 
         # same backbone as MAE
-        self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit)
+        self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
 
         # the decoder is a simple linear layer
-        self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size**2 * 3)
+        self.decoder = nn.Linear(decoder_dim, vit.patch_size**2 * 3)
 
     def forward_encoder(self, images, batch_size, idx_mask):
-        # pass all the tokens to the encoder, both masked and non masked ones
-        tokens = self.backbone.images_to_tokens(images, prepend_class_token=True)
-        tokens_masked = utils.mask_at_index(tokens, idx_mask, self.mask_token)
-        return self.backbone.encoder(tokens_masked)
+        return self.backbone.encode(images=images, idx_mask=idx_mask)
 
     def forward_decoder(self, x_encoded):
         return self.decoder(x_encoded)
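With the wrapper, masking moves inside `encode`: passing `idx_mask` lets the backbone substitute its own learnable mask token before running the encoder, which is why the example drops its separate `self.mask_token`. A rough usage sketch under the same assumptions as the example (torchvision ViT, 75% mask ratio); `vit.seq_length` counts the class token:

    import torch
    import torchvision
    from lightly.models import utils
    from lightly.models.modules import MaskedVisionTransformerTorchvision

    vit = torchvision.models.vit_b_32()
    backbone = MaskedVisionTransformerTorchvision(vit=vit)

    images = torch.randn(2, 3, 224, 224)
    idx_keep, idx_mask = utils.random_token_mask(
        size=(2, vit.seq_length),
        mask_ratio=0.75,
        device=images.device,
    )
    # All tokens pass through the encoder; the ones at idx_mask are
    # replaced by the wrapper's mask token (SimMIM-style masking).
    x_encoded = backbone.encode(images=images, idx_mask=idx_mask)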