Add Masked Autoencoder implementation #721

Closed
IgorSusmelj opened this issue Mar 6, 2022 · 19 comments · Fixed by #799

@IgorSusmelj
Contributor

The paper Masked Autoencoders Are Scalable Vision Learners
(https://arxiv.org/abs/2111.06377) suggests that a masked autoencoder (similar to masked pre-training in NLP) works very well as a pretext task for self-supervised learning. Let's add it to Lightly.


@Atharva-Phatak
Contributor

I would like to work on this. Do we have reference code implementations for it? I am not sure if I will be able to reproduce the results as I do not have that much hardware, but this seems interesting.

@philippmwirth
Copy link
Contributor

You can check out Papers with Code. They reference this repo by Facebook Research.

It'd be great if you could share your thoughts on how you would integrate it best into the current lightly package structure 🙂

@Atharva-Phatak
Contributor

@philippmwirth Thanks for the references. I will give the paper a read and then we can discuss ideas on how to integrate it with lightly. This will be fun to do.

@Atharva-Phatak
Contributor

Atharva-Phatak commented Mar 16, 2022

I looked at the code. It seems simple enough. A few things I would like to highlight:

  • Augmentations: They use standard transformations available in torchvision (see the sketch after this list).
  • Model Architecture: The repo has its own implementation of MAE, so I guess we would need to integrate it into lightly (they also use some utils from timm). This would add an additional install dependency to lightly.
  • Training/Pre-training: The pre-training process is straightforward. A tutorial should be enough.
  • Visualization-Utils: We can add the visualization utilities they have in their demo.
  • Criterion/Loss Function: Their MAE implementation returns the loss as well; we can create a separate module for it.
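
For the augmentations point, a minimal sketch of the kind of torchvision transforms used for MAE pre-training; the exact parameters are illustrative and should be double-checked against facebookresearch/mae:

import torchvision.transforms as T

# Assumed MAE-style pre-training transforms: random resized crop, horizontal
# flip, and ImageNet normalization. Parameters here are illustrative only.
mae_transform = T.Compose([
    T.RandomResizedCrop(224, scale=(0.2, 1.0), interpolation=T.InterpolationMode.BICUBIC),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])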

All in all, this seems like a nice addition to lightly. I can implement it and bring it up to the lightly code standards. What I would need from the lightly team is help with experimentation (I do not have enough hardware to replicate the results) and guidance on how to write tests for this implementation.

Please let me know your thoughts @philippmwirth @IgorSusmelj.

@IgorSusmelj
Contributor Author

Hi @Atharva-Phatak, thanks for the summary!

That looks great.

  • Augmentations: should be used from torchvision whenever possible.
  • Model Architecture: I'd try to avoid adding timm as a dependency. Any thoughts @guarin and @philippmwirth?
  • Visualization-Utils: we can add this later, or one could use the colab from facebook :)
  • Criterion/Loss Function: this should be separate from the model.

Overall, we try to make lightly rather modular. That will make it easier to combine different architectures, training procedures, and loss functions.
I guess the key question is how to split up the MaskedAutoencoderViT into good independent pieces.

By the way, torchvision just added lots of new augmentations and ViT models. Maybe we could build on top of those?

@IgorSusmelj
Contributor Author

We can test the whole implementation on our hardware.

@Atharva-Phatak
Contributor

Atharva-Phatak commented Mar 16, 2022

@IgorSusmelj We can adapt the code from timm; it seems pretty easy to integrate. This would remove the additional timm dependency.

@philippmwirth
Contributor

Please correct me if I'm wrong, but wouldn't it be enough to e.g. inherit from the torchvision ViT implementation (link) and simply override the forward function? Something like this:

class MaskedAutoencoderViT(torchvision.models.VisionTransformer):

    def forward(self, x: torch.Tensor, mask_ratio: float):

        x = self._process_input(x)
        n = x.shape[0]

        # new: random masking (helper method still to be added)
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)

        return x, mask, ids_restore

And then we could add modules for the decoder and for the loss to lightly. Sharing the architecture with torchvision has the simple advantage that exporting and importing weights is very convenient.
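
For example, a rough sketch of that convenience, assuming the MaskedAutoencoderViT class sketched above (built with vit_b_16 hyperparameters) and ignoring the still-missing random_masking helper:

import torch
import torchvision

# Hypothetical: MaskedAutoencoderViT from the snippet above, constructed with
# the same hyperparameters as torchvision's vit_b_16.
mae_vit = MaskedAutoencoderViT(
    image_size=224,
    patch_size=16,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
)

# ... MAE pre-training would happen here ...
torch.save(mae_vit.state_dict(), "mae_pretrained.pth")

# Because the subclass only overrides forward(), its state dict matches the
# plain torchvision model, so the weights load without any conversion.
vit = torchvision.models.vit_b_16()
vit.load_state_dict(torch.load("mae_pretrained.pth"))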

@guarin I think you looked into this before. Anything I'm missing?

@guarin
Contributor

guarin commented Mar 17, 2022

I remember that the masking + positional encoding part was not trivial, especially because it is needed in both the encoder and the decoder. And it did not fit nicely into our workflow where we always return (image, target, filename) tuples. But I am sure we can figure it out.

I would also use the PyTorch ViT and first focus on implementing the model without any visualizations.

@Atharva-Phatak
Contributor

@philippmwirth I agree, the torchvision version looks simple enough. If everyone concurs, I will study the code in torchvision and we can implement MAE accordingly.

Please let me know.

@philippmwirth
Contributor

Sounds great @Atharva-Phatak! Let us know if you need support 🙂

@Atharva-Phatak
Contributor

I am sorry this is taking time from my end as I am busy with my final exams 😭. I will try to create a PR ASAP.

@philippmwirth
Contributor

No worries, good luck with your exams!

@philippmwirth
Contributor

I've been doing some investigations and somehow this is the best I've come up with on how to make the torchvision ViTs work with Lightly and the MAE setup. @guarin @Atharva-Phatak @IgorSusmelj I'd love to hear your opinions. IMO it's not a very clean approach but I think it should work... If you have better ideas let me know 🙂

Possible approach for an MAE implementation (or rather, how to pre-train a torchvision.models.VisionTransformer with MAE):

# initialize ViT
vit = torchvision.models.vit_b_16(pretrained=False)

# use a lightly MaskedEncoder which inherits from torchvision.models.vision_transformer.Encoder
encoder = lightly.modules.MaskedEncoder.from_encoder(vit.encoder)

# use a lightly MaskedDecoder
decoder = lightly.MaskedDecoder()

# use the loss implemented by lightly
loss = MAELoss()

# pre-training
for i in range(epochs):
    for x in dataloader:
        # x is a batch of images (bsz, 3, w, h)

        # need to process the input (patchify & embed)
        x_processed = vit._process_input(x)

        # manually add the cls token
        n = x_processed.shape[0]
        batch_class_token = vit.class_token.expand(n, -1, -1)
        x_processed = torch.cat([batch_class_token, x_processed], dim=1)

        # forward pass encoder
        x_encoded, mask, ids_restore = encoder(x_processed)

        # forward pass decoder
        x_decoded = decoder(x_encoded, mask, ids_restore)

        # possibly convert x_decoded to image patches here

        # loss calculation
        l = loss(x, x_decoded)
        # backwards pass etc

# restore original encoder with pretrained weights
vit.encoder = encoder.strip_mask()

This would require us to add the following things to lightly:

  • MaskedEncoder
  • MaskedDecoder
  • MAELoss (a rough sketch follows below)
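
As a starting point, a hypothetical sketch of what such an MAELoss module could look like: mean squared error on the masked patches, with the optional per-patch pixel normalization from the paper. This is not an existing lightly class, just an assumed shape for discussion.

import torch


class MAELoss(torch.nn.Module):
    def __init__(self, norm_pix_loss: bool = False):
        super().__init__()
        self.norm_pix_loss = norm_pix_loss

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # pred and target have shape (batch_size, num_masked_patches, patch_dim)
        if self.norm_pix_loss:
            # normalize each target patch by its own mean and variance
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5
        # average squared error over pixels, patches, and batch
        return ((pred - target) ** 2).mean()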

@guarin
Contributor

guarin commented Apr 23, 2022

This looks great!

Based on your proposal I was able to write the following draft that successfully runs; I'm not sure if it actually works though 🙂

The code is pretty verbose as I have not yet figured out an optimal structure, but all the building blocks are there. The main issue is that encoding, masking, and decoding are pretty interleaved and have to share a lot of information with each other. So maybe an overall MAE class that holds the encoder, decoder, class token, and mask token could be a good solution (a sketch of what such a class could look like follows after the draft below), although this would go a bit against our "low-level" building blocks principle.

This code should also be pretty easy to adapt to the SimMIM and SplitMask models.

The code is adapted from: https://github.com/facebookresearch/mae

from typing import Optional

import torch
import torchvision
import lightly
import tqdm


def repeat_token_like(token, input):
    # repeats token to have same shape as input
    N, S, _ = input.shape
    return token.repeat(N, S, 1)

def expand_index_like(idx, input):
    # expands the index along the feature dimension of input
    # returns idx with shape (N_idx, S_idx, D_input)
    D = input.shape[-1]
    idx = idx.unsqueeze(-1).expand(-1, -1, D)
    return idx

def get_at_index(input, idx):
    # gets tokens at index
    idx = expand_index_like(idx, input)
    return torch.gather(input, 1, idx)

def set_at_index(input, idx, value):
    # sets tokens at index to value
    idx = expand_index_like(idx, input)
    return torch.scatter(input, 1, idx, value)

def prepend_class_token(input, class_token):
    # prepends class token to input
    N = input.shape[0]
    batch_class_token = class_token.expand(N, -1, -1)
    return torch.cat([batch_class_token, input], dim=1)

def create_random_mask(input, mask_ratio=0.6):
    # creates random masks for input
    # returns idx_keep, idx_mask tuple
    # idx_keep has shape (N, num_keep)
    # idx_mask has shape (N, S - num_keep)
    
    # S = sequence length
    N, S, _ = input.shape
    num_keep = int(S * (1 - mask_ratio))
    
    noise = torch.rand(N, S, device=input.device)
    # make sure that class token is not masked
    noise[:, 0] = -1
    
    # get indices of tokens to keep
    indices = torch.argsort(noise, dim=1)
    idx_keep = indices[:, :num_keep]
    idx_mask = indices[:, num_keep:]
    
    return idx_keep, idx_mask

def patchify(imgs, patch_size):
    # converts images into patches
    # output has shape (N, num_patches, patch_size ** 2 * C)
    N, C, H, W = imgs.shape
    assert H == W and H % patch_size == 0

    patch_h = patch_w = H // patch_size
    num_patches = patch_h * patch_w
    patches = imgs.reshape(shape=(N, C, patch_h, patch_size, patch_w, patch_size))
    patches = torch.einsum('nchpwq->nhwpqc', patches)
    patches = patches.reshape(shape=(N, num_patches, patch_size ** 2 * C))
    return patches


class MAEEncoder(torchvision.models.vision_transformer.Encoder):        
    
    def forward(self, input: torch.Tensor, idx_keep: Optional[torch.Tensor] = None):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        if idx_keep is not None:
            input = get_at_index(input, idx_keep)
        return self.ln(self.layers(self.dropout(input)))

    @classmethod
    def from_vit_encoder(cls, vit_encoder):
        encoder = cls(
            seq_length=1,
            num_layers=1,
            num_heads=1,
            hidden_dim=1,
            mlp_dim=1,
            dropout=0,
            attention_dropout=0,
        )
        encoder.pos_embedding = vit_encoder.pos_embedding
        encoder.dropout = vit_encoder.dropout
        encoder.layers = vit_encoder.layers
        encoder.ln = vit_encoder.ln
        return encoder


class MAEDecoder(torchvision.models.vision_transformer.Encoder):
    def __init__(
        self, 
        embed_input_dim, 
        patch_size,
        hidden_dim,
        **kwargs,
    ):
        super().__init__(hidden_dim=hidden_dim, **kwargs)
        self.decoder_embed = torch.nn.Linear(embed_input_dim, hidden_dim, bias=True)
        self.prediction_head = torch.nn.Linear(hidden_dim, patch_size ** 2 * 3)
        
    def forward(self, input):
        return self.decode(input)
        
    def embed(self, input):
        return self.decoder_embed(input)
    
    def decode(self, input):
        return super().forward(input)
    
    def predict(self, input):
        return self.prediction_head(input)


vit = torchvision.models.vit_b_32(pretrained=True)

decoder_dim = 512
class_token = vit.class_token
mask_token = torch.nn.Parameter(torch.zeros(1, 1, decoder_dim))

encoder = MAEEncoder.from_vit_encoder(vit.encoder)
decoder = MAEDecoder(
    embed_input_dim=vit.hidden_dim,
    patch_size=vit.patch_size,
    seq_length=vit.seq_length,
    num_layers=1,
    num_heads=4,
    hidden_dim=decoder_dim,
    mlp_dim=decoder_dim * 4,
    dropout=0,
    attention_dropout=0,
)

transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop((vit.image_size, vit.image_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        lightly.data.collate.imagenet_normalize['mean'],
        lightly.data.collate.imagenet_normalize['std'],
    )
])

dataset = lightly.data.LightlyDataset('/datasets/aquarium', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, num_workers=4, batch_size=4, drop_last=True)

optimizer = torch.optim.Adam(
    params=(
        [class_token, mask_token]
        + list(encoder.parameters())
        + list(decoder.parameters())
    ),
    lr=0.06,
)
criterion = torch.nn.MSELoss()


# pre-training
for epoch in range(10):
    epoch_loss = 0
    for imgs, targets, filenames in tqdm.tqdm(dataloader):
        # imgs is a batch of images (bsz, 3, w, h)

        # need to process the input (patchify & embed)
        x_processed = vit._process_input(imgs)

        # add the cls token
        x_processed = prepend_class_token(x_processed, class_token)
        
        # get mask indices
        idx_keep, idx_mask = create_random_mask(x_processed)

        # forward pass encoder, only non-masked tokens are encoded
        x_encoded_keep = encoder(x_processed, idx_keep)
        
        # project to decoder input dimension
        x_decode_embed_keep = decoder.embed(x_encoded_keep)

        # build masked decoder input
        # masked tokens are set to the mask_token
        # non-masked tokens are set to the embedded encoder tokens
        x_masked = repeat_token_like(mask_token, x_processed)
        x_masked = set_at_index(x_masked, idx_keep, x_decode_embed_keep)

        # forward pass decoder
        x_decoded = decoder(x_masked)
        
        # predict pixel values for masked tokens
        x_pred = get_at_index(x_decoded, idx_mask)
        x_pred = decoder.predict(x_pred)
        
        # get image patches for masked tokens
        # must adjust idx_mask for missing class token
        patches = patchify(imgs, vit.patch_size)
        target = get_at_index(patches, idx_mask - 1)
        
        loss = criterion(x_pred, target)
        
        # backwards pass etc
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_loss += loss.detach()
    
    print(epoch, epoch_loss)
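
To make the trade-off above concrete, here is one hypothetical shape the "overall MAE class" could take, reusing the helpers and modules from this draft; names and defaults are illustrative, not a settled lightly API:

import torch
import torchvision


class MAE(torch.nn.Module):
    # Hypothetical wrapper holding the encoder, decoder, class token, and mask
    # token; relies on MAEEncoder, MAEDecoder, and the helper functions above.
    def __init__(self, vit: torchvision.models.VisionTransformer, decoder_dim: int = 512):
        super().__init__()
        self.vit = vit
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.encoder = MAEEncoder.from_vit_encoder(vit.encoder)
        self.decoder = MAEDecoder(
            embed_input_dim=vit.hidden_dim,
            patch_size=vit.patch_size,
            seq_length=vit.seq_length,
            num_layers=1,
            num_heads=4,
            hidden_dim=decoder_dim,
            mlp_dim=decoder_dim * 4,
            dropout=0,
            attention_dropout=0,
        )

    def forward(self, images: torch.Tensor, mask_ratio: float = 0.6):
        # patchify + embed and prepend the class token (both taken from the ViT)
        x = self.vit._process_input(images)
        x = prepend_class_token(x, self.vit.class_token)
        # mask, encode only the kept tokens, then decode the full sequence
        idx_keep, idx_mask = create_random_mask(x, mask_ratio)
        x_encoded = self.encoder(x, idx_keep)
        x_embed = self.decoder.embed(x_encoded)
        x_masked = repeat_token_like(self.mask_token, x)
        x_masked = set_at_index(x_masked, idx_keep, x_embed)
        x_decoded = self.decoder(x_masked)
        # predict pixel values only for the masked tokens
        x_pred = self.decoder.predict(get_at_index(x_decoded, idx_mask))
        return x_pred, idx_keep, idx_mask

With this, the training loop above would shrink to a single forward pass through the wrapper plus the patchify/get_at_index target computation and the loss.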

@Atharva-Phatak
Contributor

Atharva-Phatak commented Apr 23, 2022

@guarin I was implementing the Encoder and Decoder as different classes, similar to the heads, and then I was going to combine them in a single class so that information sharing between the modules would be a bit easier. I think @guarin's approach seems more elegant and easier to adapt. That being said, which approach should I follow?

@philippmwirth
Contributor

Great draft @guarin! 🙂

To answer both your questions: I'd suggest we start building the low-level blocks first (i.e. MAEEncoder and MAEDecoder) similar to what @guarin used above. We can always add a high-level interface which connects the two later. For example, we can add lightly.models.modules.encoders.MAEEncoder and lightly.models.modules.encoders.MAEDecoder in a first step and then later (if necessary) we'll work on lightly.models.mae.MAE.

Ideally, we'd have a working version of the encoder and decoder relatively soon so we can run a quick benchmark on e.g. Imagenette to see if it works as expected. We can then work on the final implementation together 👍

There might be some differences between the original paper and your implementation, @guarin:

  • MAE uses sine-cosine positional embeddings while torchvision uses learned ones (I believe)

That's not dramatic but we should make sure to note it somewhere.

@guarin
Contributor

guarin commented Apr 25, 2022

MAE uses sine-cosine positional embeddings while torchvision uses learned ones (I believe)

Aaah, good catch, I didn't notice that! Yes, we either have to add a note or overwrite the positional embedding with a sine-cosine one. Although overwriting would break pretrained ViTs, so maybe that is not the best idea.
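
For illustration, a minimal sketch of the overwriting option. This uses a simple 1D sine-cosine embedding just to show the mechanics; the MAE paper uses a 2D sine-cosine embedding over the patch grid, so the actual generator would differ.

import math

import torch
import torchvision


def sincos_pos_embedding(seq_length: int, hidden_dim: int) -> torch.Tensor:
    # Fixed 1D sine-cosine embedding of shape (1, seq_length, hidden_dim).
    position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, hidden_dim, 2, dtype=torch.float32)
        * (-math.log(10000.0) / hidden_dim)
    )
    pos_embedding = torch.zeros(1, seq_length, hidden_dim)
    pos_embedding[0, :, 0::2] = torch.sin(position * div_term)
    pos_embedding[0, :, 1::2] = torch.cos(position * div_term)
    return pos_embedding


vit = torchvision.models.vit_b_32(pretrained=False)
# Replace the learned positional embedding with a fixed one; note that this
# makes existing pretrained ViT weights inconsistent, as discussed above.
vit.encoder.pos_embedding = torch.nn.Parameter(
    sincos_pos_embedding(vit.seq_length, vit.hidden_dim),
    requires_grad=False,
)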

Regarding cleanup / code structure:

  • I think we can move all the functions at the beginning of my draft into a helper file; they will probably be useful again for other ViT-based models.
  • As @philippmwirth suggested, we can move the encoder and decoder into a new encoders module; after this the example code should already be a lot shorter and easier to read.
  • I would maybe wait with introducing a MAE class until we have implemented another similar model, like SimMIM or SplitMask, then it should be easier to see what should be combined and what not.

Next steps would be:

  • Make a PR with the code split into modules
  • Run a benchmark to see if the code works --> for the benchmark we have to combine everything into a single module, which will also give some insight into how best to do this.
  • Fix stuff
  • Add unit tests (a shape-check sketch follows below)
  • Add docstrings
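
For the unit tests, a pytest-style sketch of simple shape checks, assuming the helper functions from the draft above are importable:

import torch

# Assumes create_random_mask and patchify from the draft above are in scope.

def test_create_random_mask_shapes():
    x = torch.rand(2, 50, 768)  # (batch_size, seq_length, hidden_dim)
    idx_keep, idx_mask = create_random_mask(x, mask_ratio=0.6)
    num_keep = int(50 * (1 - 0.6))
    assert idx_keep.shape == (2, num_keep)
    assert idx_mask.shape == (2, 50 - num_keep)


def test_patchify_shapes():
    imgs = torch.rand(2, 3, 224, 224)
    patches = patchify(imgs, patch_size=32)
    assert patches.shape == (2, (224 // 32) ** 2, 32 ** 2 * 3)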

@Atharva-Phatak do you already have some example code/draft? It would be great if we could compare :)

@Atharva-Phatak
Contributor

Hi @guarin, I am adding my implementation of the encoder below; my decoder is the same as yours.
So all in all, @guarin, we should move ahead with your structure; the only thing is that we need to structure the helper utilities properly.

import torch
import torchvision

from utils import random_masking


class MaskedEncoderVIT(torchvision.models.VisionTransformer):

    def forward(self, x: torch.Tensor, mask_ratio: float):

        x = self._process_input(x)
        n = x.shape[0]

        # new: random masking (helper added in utils.py)
        x, mask, ids_to_restore = random_masking(x, mask_ratio)

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)
        x = x[:, 0]
        x = self.heads(x)
        return x, mask, ids_to_restore
