Skip to content

Commit

Permalink
clearer mae
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 22, 2021
1 parent 5ae5557 commit 9f8c606
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions vit_pytorch/mae.py
Expand Up @@ -78,12 +78,12 @@ def forward(self, img):

# concat the masked tokens to the decoder tokens and attend with decoder

decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim = 1)
decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
decoded_tokens = self.decoder(decoder_tokens)

# splice out the mask tokens and project to pixel values

mask_tokens = decoded_tokens[:, -num_masked:]
mask_tokens = decoded_tokens[:, :num_masked]
pred_pixel_values = self.to_pixels(mask_tokens)

# calculate reconstruction loss
Expand Down

0 comments on commit 9f8c606

Please sign in to comment.