Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastien Ehrhardt committed May 14, 2024
1 parent 38642b5 commit fe66ab7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/transformers/models/vit_mae/modeling_vit_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,18 +241,18 @@ def random_masking(self, sequence, noise=None):
noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]

# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

# keep the first subset
ids_keep = ids_shuffle[:, :len_keep].to(sequence.device)
ids_keep = ids_shuffle[:, :len_keep]
sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([batch_size, seq_length], device=sequence.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore.to(sequence.device))
mask = torch.gather(mask, dim=1, index=ids_restore)

return sequence_unmasked, mask, ids_restore

Expand Down

0 comments on commit fe66ab7

Please sign in to comment.