
Commit

address #306
lucidrains committed Apr 18, 2024
1 parent 12249dc commit 96f66d2
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.6.6',
+  version = '1.6.7',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
10 changes: 5 additions & 5 deletions vit_pytorch/na_vit.py
@@ -198,7 +198,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
             self.calc_token_dropout = token_dropout_prob

         elif isinstance(token_dropout_prob, (float, int)):
-            assert 0. < token_dropout_prob < 1.
+            assert 0. <= token_dropout_prob < 1.
             token_dropout_prob = float(token_dropout_prob)
             self.calc_token_dropout = lambda height, width: token_dropout_prob

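With the relaxed bound, a numeric token_dropout_prob of exactly 0. is now accepted (no tokens dropped) instead of tripping the assertion. A minimal, self-contained sketch of that numeric branch; the helper name make_calc_token_dropout is illustrative and not part of the library:

def make_calc_token_dropout(token_dropout_prob):
    # mirrors the constructor's handling of token_dropout_prob after this commit
    if callable(token_dropout_prob):
        return token_dropout_prob
    elif isinstance(token_dropout_prob, (float, int)):
        assert 0. <= token_dropout_prob < 1.  # 0. is now allowed: keep every token
        prob = float(token_dropout_prob)
        return lambda height, width: prob
    return None

calc = make_calc_token_dropout(0.)
assert calc(224, 224) == 0.  # previously raised, since 0. < 0. is False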
@@ -249,7 +249,7 @@ def forward(
         group_images = False,
         group_max_seq_len = 2048
     ):
-        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
+        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout) and self.training

         arange = partial(torch.arange, device = device)
         pad_sequence = partial(orig_pad_sequence, batch_first = True)
@@ -260,7 +260,7 @@
             batched_images = group_images_by_max_seq_len(
                 batched_images,
                 patch_size = self.patch_size,
-                calc_token_dropout = self.calc_token_dropout,
+                calc_token_dropout = self.calc_token_dropout if self.training else None,
                 max_seq_len = group_max_seq_len
             )

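Taken together, the two hunks above make token dropout a train-time-only behavior: has_token_dropout is False in eval mode, and the grouping helper receives calc_token_dropout = None at inference. A hedged usage sketch, with constructor arguments modeled on the NaViT example in the repository README (the values are illustrative):

import torch
from vit_pytorch.na_vit import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 4,
    heads = 8,
    mlp_dim = 1024,
    token_dropout_prob = 0.25  # drop roughly 25% of patch tokens while training
)

# variable-resolution images, packed as a list of lists of (c, h, w) tensors
images = [[torch.randn(3, 256, 256), torch.randn(3, 128, 128)]]

v.train()
train_logits = v(images)      # token dropout active: patch sequences are shortened

v.eval()
with torch.no_grad():
    eval_logits = v(images)   # after this commit, no tokens are dropped at inference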
@@ -314,8 +314,8 @@ def forward(
         # derive key padding mask

         lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
-        max_length = arange(lengths.amax().item())
-        key_pad_mask = rearrange(lengths, 'b -> b 1') <= rearrange(max_length, 'n -> 1 n')
+        seq_arange = arange(lengths.amax().item())
+        key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')

         # derive attention mask, and combine with key padding mask from above

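The rewritten mask is True at positions holding real tokens (position index < sequence length) rather than at padding, with seq_arange as a clearer name for the position range. A toy check of the new expression, using made-up lengths:

import torch
from einops import rearrange

lengths = torch.tensor([3, 5])                    # packed sequence lengths (illustrative)
seq_arange = torch.arange(lengths.amax().item())  # positions 0..4

key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])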
