Add Masked Autoencoder implementation #721
I would like to work on this. Do we have reference code implementations for it? I am not sure I will be able to reproduce the results, as I do not have that much hardware, but this seems interesting.
You can check out Papers with Code; it references this repo by Facebook Research. It'd be great if you could share your thoughts on how you would best integrate it into the current lightly package structure 🙂
@philippmwirth Thanks for the references. I will give the paper a read and then we can discuss ideas on how to integrate it with lightly. This will be fun to do.
I looked at the code. It seems simple enough. A few things I would like to highlight:
All in all, this seems like a nice implementation for lightly. I can implement it and bring it up to the lightly code standards. What I would need from the lightly team is help with experimentation (I do not have that much hardware to replicate the results) and guidance on how to write tests for this implementation. Please let me know your thoughts @philippmwirth @IgorSusmelj.
Hi @Atharva-Phatak, thanks for the summary! That looks great. Augmentations should be used from torchvision whenever possible. Overall, we try to make lightly rather modular. That will make it easier to combine different architectures, training procedures, and loss functions. Btw. torchvision just added lots of new augmentations and ViT models. Maybe we could build on top of it?
We can test the whole implementation on our hardware.
@IgorSusmelj We can adapt the code from
Please correct me if I'm wrong, but wouldn't it be enough to e.g. inherit from the torchvision ViT?

```python
class MaskedAutoencoderViT(torchvision.models.VisionTransformer):
    def forward(self, x: torch.Tensor, mask_ratio: float):
        x = self._process_input(x)
        n = x.shape[0]

        # new: random masking
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]
        x = self.heads(x)

        return x, mask, ids_restore
```

And then we could add modules for the decoder and for the loss to lightly. Sharing the architecture with torchvision has the simple advantage that exporting and importing weights is very convenient. @guarin I think you looked into this before. Anything I'm missing?
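For context, `random_masking` is not part of the torchvision `VisionTransformer` and would have to be added. A self-contained sketch along the lines of the facebookresearch/mae reference implementation (names and shapes are illustrative):

```python
import torch

def random_masking(x, mask_ratio):
    """Randomly mask out a fraction of patch tokens, MAE-style.

    x: (batch, num_patches, dim) patch embeddings (without class token).
    Returns the kept tokens, a binary mask in the original patch order
    (0 = kept, 1 = masked), and the indices needed by the decoder to
    restore the original patch order.
    """
    batch, num_patches, dim = x.shape
    num_keep = int(num_patches * (1 - mask_ratio))

    # Random scores per patch; sorting them yields a random permutation.
    noise = torch.rand(batch, num_patches, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation

    # Keep the first subset of the shuffled patches.
    ids_keep = ids_shuffle[:, :num_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, dim))

    # Build the binary mask and un-shuffle it into the original order.
    mask = torch.ones(batch, num_patches, device=x.device)
    mask[:, :num_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore
```

Shuffling via `argsort` of random noise is the trick the reference repo uses to draw a fresh random subset of patches per sample in a fully batched way.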
I remember that the masking + positional encoding part was not trivial, especially because it is needed in both the encoder and the decoder. And it did not fit nicely into our workflow, where we always return (image, target, filename) tuples. But I am sure we can figure it out. I would also use the pytorch ViT and first focus on implementing the model without any visualizations.
@philippmwirth I agree, the torchvision version looks simple enough. If everyone concurs, I will study the code in torchvision and we can implement MAE accordingly. Please let me know.
Sounds great @Atharva-Phatak! Let us know if you need support 🙂
I am sorry this is taking time on my end, as I am busy with my final exams 😭. I will try to create a PR ASAP.
No worries, good luck with your exams! |
I've been doing some investigations and somehow this is the best I've come up with on how to make the torchvision ViTs work with Lightly and the MAE setup. @guarin @Atharva-Phatak @IgorSusmelj I'd love to hear your opinions. IMO it's not a very clean approach but I think it should work... If you have better ideas let me know 🙂

Possible approach for an MAE implementation (or rather how to pretrain a torchvision ViT):

```python
# initialize ViT
vit = torchvision.models.vit_b_16(pretrained=False)

# use a lightly MaskedEncoder which inherits from torchvision.models.vision_transformer.Encoder
encoder = lightly.modules.MaskedEncoder.from_encoder(vit.encoder)

# use a lightly MaskedDecoder
decoder = lightly.MaskedDecoder()

# use the loss implemented by lightly
loss = MAELoss()

# pre-training
for i in range(epochs):
    for x in dataloader:
        # x is a batch of images (bsz, 3, w, h)

        # need to process the input (patchify & embed)
        x_processed = vit._process_input(x)

        # manually add the cls token
        n = x_processed.shape[0]
        batch_class_token = vit.class_token.expand(n, -1, -1)
        x_processed = torch.cat([batch_class_token, x_processed], dim=1)

        # forward pass encoder
        x_encoded, mask, ids_restore = encoder(x_processed)

        # forward pass decoder
        x_decoded = decoder(x_encoded, mask, ids_restore)

        # possibly convert x_decoded to image patches here

        # loss calculation
        l = loss(x, x_decoded)

        # backwards pass etc.

# restore original encoder with pretrained weights
vit.encoder = encoder.strip_mask()
```

This would require us to add the following things to lightly:
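Regarding the `loss = MAELoss()` and the "convert x_decoded to image patches" steps in the loop above: a rough sketch of what such a loss could look like, following the reference repo's per-patch MSE restricted to masked patches. The `patchify` layout and argument names are illustrative, and the optional per-patch normalization of target pixels from the official code is omitted here:

```python
import torch

def patchify(imgs, patch_size=16):
    """Convert (bsz, 3, h, w) images into (bsz, num_patches, patch_size**2 * 3)."""
    b, c, h, w = imgs.shape
    ph, pw = h // patch_size, w // patch_size
    x = imgs.reshape(b, c, ph, patch_size, pw, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1).reshape(b, ph * pw, patch_size ** 2 * c)
    return x

class MAELoss(torch.nn.Module):
    """Mean squared error computed only on the masked patches."""

    def forward(self, images, predictions, mask, patch_size=16):
        # predictions: (bsz, num_patches, patch_size**2 * 3) from the decoder
        # mask: (bsz, num_patches) with 1 for masked patches, 0 for kept ones
        target = patchify(images, patch_size)
        loss = (predictions - target) ** 2
        loss = loss.mean(dim=-1)                    # per-patch loss
        return (loss * mask).sum() / mask.sum()     # mean over masked patches
```

Restricting the loss to masked patches matters: the paper reports that computing it on all patches slightly hurts accuracy.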
This looks great! Based on your proposal I was able to write the following draft that successfully runs, not sure if it actually works though 🙂 The code is pretty verbose as I have not yet figured out an optimal structure but all the building blocks are there. The main issue is that encoding, masking, and decoding are pretty interleaved and have to share a lot of information between each other. So maybe an overall MAE class that holds the encoder, decoder, class token, and mask token could be a good solution, although this would be a bit against our "low-level" building blocks principle. This code should also be pretty easy to adapt to the SimMIM and SplitMask models. The code is adapted from: https://github.com/facebookresearch/mae
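The "overall MAE class" idea from the comment above could be sketched roughly like this; `encoder`, `decoder`, and `masking_fn` are placeholders for the modules discussed in this thread, not existing lightly APIs:

```python
import torch

class MaskedAutoencoder(torch.nn.Module):
    """Illustrative wrapper holding the pieces that must share information.

    Keeping the encoder, decoder, class token, and masking function in one
    module makes the shared state (mask, ids_restore) internal, so the
    training loop only sees patch embeddings in and reconstructions out.
    """

    def __init__(self, encoder, decoder, masking_fn, embed_dim):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.masking_fn = masking_fn
        self.class_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, patches, mask_ratio=0.75):
        # patches: (bsz, num_patches, embed_dim), already patchified & embedded
        tokens, mask, ids_restore = self.masking_fn(patches, mask_ratio)
        cls = self.class_token.expand(tokens.shape[0], -1, -1)
        encoded = self.encoder(torch.cat([cls, tokens], dim=1))
        decoded = self.decoder(encoded, ids_restore)
        return decoded, mask
```

The trade-off is as stated in the thread: this is convenient for MAE but less "low-level" than lightly's usual building blocks.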
@guarin I was implementing the Encoder and Decoder as different classes, similar to heads, and then I was going to combine them in a single class so that information sharing between the modules is a bit easier. That said, @guarin's approach seems more elegant and easier to adapt. Which approach should I follow?
Great draft @guarin! 🙂 To answer both your questions: I'd suggest we start building the low-level blocks first (i.e. the masked encoder, decoder, and loss modules). Ideally, we'd have a working version of the encoder and decoder relatively soon so we can run a quick benchmark on e.g. Imagenette to see if it works as expected. We can then work on the final implementation together 👍

There might be some differences between the original paper and your implementation, @guarin. That's not dramatic, but we should make sure to note it somewhere.
Aaah good catch, I didn't notice that! Yes, we either have to add a note or overwrite the positional embedding with a sine-cosine one. Although overwriting would break pretrained ViTs, so maybe that is not the best idea. Regarding cleanup / code structure:
Next steps would be:
@Atharva-Phatak do you already have some example code/draft? Would be great if we can compare :) |
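As a side note on the sine-cosine option discussed above: the fixed positional embedding in the MAE reference code is built from a standard 1D sine-cosine table (the 2D patch-grid version concatenates two such tables over grid height and width). A minimal sketch, with illustrative names:

```python
import torch

def sincos_pos_embed_1d(embed_dim, num_positions):
    """Fixed 1D sine-cosine positional embedding table.

    Returns a (num_positions, embed_dim) tensor; the first half of each
    row holds sines, the second half cosines, over geometrically spaced
    frequencies as in the original Transformer.
    """
    assert embed_dim % 2 == 0
    positions = torch.arange(num_positions, dtype=torch.float32)
    omega = torch.arange(embed_dim // 2, dtype=torch.float32) / (embed_dim / 2)
    omega = 1.0 / 10000 ** omega                       # (embed_dim / 2,)
    out = positions[:, None] * omega[None, :]          # (num_positions, embed_dim / 2)
    return torch.cat([out.sin(), out.cos()], dim=1)    # (num_positions, embed_dim)
```

Because the table is deterministic, it can be registered as a non-trainable buffer; overwriting a learnable `pos_embedding` with it is exactly what would break weights pretrained with the learnable variant.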
Hi @guarin, I am adding my implementation of the encoder. My decoder is the same as yours.
The paper Masked Autoencoders Are Scalable Vision Learners (https://arxiv.org/abs/2111.06377) suggests that a masked autoencoder (similar to pre-training in NLP) works very well as a pretext task for self-supervised learning. Let's add it to Lightly.