[Feat] Add PARSeq model TF and PT #1205
Conversation
Hi @NicolasPlaye 👋,
thanks a lot for opening the PR and working on this 👍
First, a few general comments:
- the sequence building is done in the base class of a model (ref.: https://github.com/mindee/doctr/blob/main/doctr/models/recognition/vitstr/base.py), so there is no need to add parseq/utils.py
- the models in classification are our backbones, so parseq should be removed from there
Some suggestions on how we can go further:
To add the PARSeq model you only need to add 2 files:
- doctr/models/recognition/parseq/base.py (should be the same as vitstr, only rename the class)
- doctr/models/recognition/parseq/pytorch.py (very similar to https://github.com/mindee/doctr/blob/main/doctr/models/recognition/vitstr/pytorch.py)
- in this file you need to add the decoder part of PARSeq; for the encoder we can use our vit_s model (from classification)
I would say let's start with this; the other stuff afterwards is only minor changes :) (a rough sketch follows below)
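For illustration, a minimal, hedged sketch of how such a parseq/pytorch.py could be laid out (class and argument names like PARSeq, feature_extractor and embedding_units are assumptions for this example, not the final doctr API, and nn.TransformerDecoderLayer only stands in for the real PARSeq decoder block):

import torch
from torch import nn


class PARSeq(nn.Module):
    """Rough skeleton: ViT-style encoder (a classification backbone) + a small decoder and prediction head."""

    def __init__(self, feature_extractor: nn.Module, vocab: str, embedding_units: int = 384, max_length: int = 32) -> None:
        super().__init__()
        self.vocab = vocab
        self.max_length = max_length
        self.feat_extractor = feature_extractor  # e.g. a vit_s backbone returning patch embeddings
        # learned positional queries, one per output character (+1 for the EOS position)
        self.pos_queries = nn.Parameter(torch.zeros(1, max_length + 1, embedding_units))
        self.decoder = nn.TransformerDecoderLayer(embedding_units, nhead=6, batch_first=True)  # stand-in block
        self.head = nn.Linear(embedding_units, len(vocab) + 1)  # vocab + EOS

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.feat_extractor(x)                      # (N, num_patches, embedding_units)
        queries = self.pos_queries.expand(x.shape[0], -1, -1)  # one query per output position
        decoded = self.decoder(queries, features)              # cross-attend the queries to the visual features
        return self.head(decoded)                              # (N, max_length + 1, vocab_size + 1)

The actual base.py would then only carry the sequence-building / post-processing logic, mirroring the ViTSTR files.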
@nikokks
Hi again,
no need to modify the default_cfgs :)
You can init the PARSeqDecoder with these values by default, and the same for the decoding.
Then update it so we can adjust the model by passing this config as kwargs :)
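As a hedged illustration of that idea (PARSeqDecoder, embedding_units and parseq_decoder are hypothetical names here, not the actual doctr code), the defaults live in the constructor and any kwarg simply overrides them:

from torch import nn


class PARSeqDecoder(nn.Module):
    # default values live here instead of in default_cfgs
    def __init__(self, embedding_units: int = 384, num_heads: int = 6, dropout: float = 0.1) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(embedding_units, num_heads, dropout=dropout, batch_first=True)


def parseq_decoder(**kwargs) -> PARSeqDecoder:
    # anything passed as a kwarg overrides the defaults above
    return PARSeqDecoder(**kwargs)


decoder = parseq_decoder(dropout=0.0)  # example override via kwargs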
I have modified the file recognition/parseq/pytorch.py
We should keep our tokenization; at the end you can copy-paste it from the ViTSTR implementation in doctr :) So what you need to implement is the decoder in parseq/pytorch.py and the forward + compute-loss function. I would say copy the relevant stuff to parseq/pytorch.py and clean up all these classification additions.
I have done some changes, like removing the parseq tokenizer in favor of your vocab.
Do you think I should implement the decoder in recognition/parseq or in classification/parseq?
@odulcy-mindee @charlesmindee @frgfm code is ready for review
What are the commands?
For style:
And in the TensorFlow recognition ONNX test (last test case in the file), please set up the parseq test the same as master and sar above, with the min RAM check.
@nikokks should be fixed on my branch
Ok for me for the last commit on my branch!! =)
Codecov Report
@@ Coverage Diff @@
## main #1205 +/- ##
==========================================
- Coverage 94.73% 93.69% -1.04%
==========================================
Files 150 154 +4
Lines 6458 6903 +445
==========================================
+ Hits 6118 6468 +350
- Misses 340 435 +95
Flags with carried forward coverage won't be shown.
Thanks @nikokks 👍 now it's fine, let's wait for a final review :)
Thank you @nikokks for the PR and @felixdittrich92 for this review, I'll have a look at it today!
Thank you for your contribution! Really great work! The code seems fine to me.
I just have 2 questions; you can merge after that.
@felixdittrich92 a new model to add to the training list haha
target = target.clone() + self.attention_dropout(
    self.attention(query_norm, content_norm, content_norm, mask=target_mask)
)
target = target.clone() + self.cross_attention_dropout(
    self.cross_attention(self.query_norm(target), memory, memory)
)
target = target.clone() + self.feed_forward_dropout(self.position_feed_forward(self.feed_forward_norm(target)))
Why are there clone() calls here?
PyTorch does not allow overriding in place here because it would raise problems with CUDA :)
The other option would be to use 2 variables, but I personally like the clone() way -> it keeps the code a few lines shorter 😅
You can do only one clone before line 97 and remove the other clones. I tried it before and it worked, but verify to be sure. With that it should speed up processing.
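For reference, a hedged sketch of that suggestion, reusing the variable names from the decoder block quoted above (not tested against the actual PR code):

# one clone up front; the later additions are out-of-place, so no further clones are needed
target = target.clone()
target = target + self.attention_dropout(
    self.attention(query_norm, content_norm, content_norm, mask=target_mask)
)
target = target + self.cross_attention_dropout(
    self.cross_attention(self.query_norm(target), memory, memory)
)
target = target + self.feed_forward_dropout(self.position_feed_forward(self.feed_forward_norm(target)))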
mask = torch.ones((sz, sz), device=permutation.device)

for i in range(sz):
    query_idx = permutation[i]
    masked_keys = permutation[i + 1 :]
    mask[query_idx, masked_keys] = 0.0
source_mask = mask[:-1, :-1].clone()
mask[torch.eye(sz, dtype=torch.bool, device=permutation.device)] = 0.0
target_mask = mask[1:, :-1]
So we're using 0 and 1 for the mask, while the authors used float(-inf) and 0 respectively, am I correct?
Yep, for our MHA implementation 0 is masked (the transformer decoder can't "see" it; we replace it with -inf inside the scaled dot product. This masking is needed for the transformer decoder, otherwise the model would be able to "cheat") and 1 is visible; the softmax activation does the rest :)
To answer why we don't set it directly to -inf: short answer, this would raise problems with ONNX exporting :)
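A minimal self-contained sketch of what happens with the 0/1 mask inside scaled dot-product attention (illustrative only; doctr's own MHA implementation may differ in the details):

import torch


def scaled_dot_product_attention(q, k, v, mask=None):
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    if mask is not None:
        # positions where mask == 0 receive -inf, so softmax gives them ~0 attention weight
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v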
🙈 Yes, but I would keep it as the last model on the list to train (I'm still trying to debug some things, but it's really hard; these models with a transformer encoder as backbone need a ton of data 🥲😅)
I will do some tests; if I'm done I will merge it :)
Okay, thanks for the answers!
Hello, I'm the original author of PARSeq. @felixdittrich92 asked me to review your implementation. I'll add my comments here even though the PR is already closed.
Overall, it looks correct except for the masking + training loop. Good job and thanks for this initiative. :)
    logits = self.decode_non_autoregressive(features)
else:
    logits = self.decode_autoregressive(features)
Why does target determine whether the inference mode is AR (target is None) or NAR (target is not None)?
Hey 👋 we have already changed it in https://github.com/felixdittrich92/doctr/tree/parseq-fixes (that was more for further debugging) :)
# Generate attention masks for the permutations
_, target_mask = self.generate_permutations_attention_masks(perm)
# combine target padding mask and query mask
mask = (target_mask & padding_mask).int()
target_mask, as generated by generate_permutations_attention_masks(), cannot be ANDed with the padding_mask. The mask generated from a permutation is shared across all sequences (shape: (max_len, max_len)), while the padding_mask varies for each sequence (shape: (N, max_len)). If you want to AND both masks, you have to tile the target_mask for each sequence such that it becomes (N, max_len, max_len), and tile the padding_mask for each character output position, i.e. reshape it to (N, 1, max_len) first and then tile it so that it becomes the same shape.
Personally, at least for PyTorch, it would be better to use the padding_mask and target_mask separately, since this is handled automatically by the native MHA implementation.
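For completeness, a small hedged example of the "keep the two masks separate" suggestion with PyTorch's native MHA (shapes and values are made up for illustration; doctr uses its own MHA implementation instead):

import torch
from torch import nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
x = torch.randn(2, 7, 64)  # (N, max_len, d_model)
# (max_len, max_len) permutation/causal mask shared across the batch, True = not allowed to attend
attn_mask = torch.triu(torch.ones(7, 7, dtype=torch.bool), diagonal=1)
# (N, max_len) per-sequence padding mask, True = padding position
key_padding_mask = torch.zeros(2, 7, dtype=torch.bool)
key_padding_mask[:, 5:] = True
out, _ = mha(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)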
@baudm but as you can see we use our own MHA implementation (it is a bit slower than the native implementation, of course); in the past we have had some trouble with the native implementation, especially with ONNX, and our own makes it much easier for us to port it to TensorFlow :)
Pseudocode (not tested):
tiled_target_mask = target_mask.unsqueeze(0).repeat(gt.shape[0], 1, 1)  # (N, max_len, max_len)
# tile padding mask for each character output position so it matches the shape of tiled_target_mask
tiled_padding_mask = padding_mask.reshape(gt.shape[0], 1, -1).repeat(1, tiled_target_mask.shape[1], 1)  # (N, max_len, max_len)
# combine target padding mask and query mask
mask = (tiled_target_mask & tiled_padding_mask).int()
@nikokks in this case we need to remove the padding inside the generate_permutations function (see the function return)
@baudm Does it have any impact if we pad the permutation list to max_length with the eos char to ensure uniformly sized attention masks?
Still stuck 😅
Another problem is how we can pad it to self.max_length.
# Create padding mask for target input
# [True, True, True, ..., False, False, False] -> False is masked
padding_mask = ~(((gt == self.vocab_size + 2) | (gt == self.vocab_size)).int().cumsum(-1) > 0)
torch.set_printoptions(profile="full")
if self.training:
    # Generate permutations for the target sequences
    tgt_perms = self.generate_permutations(seq_len)
    print(f"Permutations: {tgt_perms}")
    loss = 0
    for perm in tgt_perms:
        print(f"Permutation: {perm}")
        # Generate attention mask for the permutation
        _, target_mask = self.generate_permutations_attention_masks(perm)
        key_padding_mask_expanded = padding_mask[:, :target_mask.shape[-1]].view(features.shape[0], 1, 1, target_mask.shape[-1]).expand(-1, 1, -1, -1)
        print(f"key_padding_mask_expanded shape: {key_padding_mask_expanded.shape}")
        print(f"key_padding_mask_expanded: \n{key_padding_mask_expanded}")
        target_mask = target_mask.view(1, 1, target_mask.shape[-1], target_mask.shape[-1]).expand(features.shape[0], 1, -1, -1)
        print(f"target_mask shape: {target_mask.shape}")
        print(f"target_mask: \n{target_mask}")
        mask = (key_padding_mask_expanded.bool() & target_mask.bool()).int()
        print(f"mask shape: {mask.shape}")
        print(f"mask: \n{mask}")
        logits = self.head(self.decode(gt[:, :target_mask.shape[-1]], features, mask))  # (N, max_length, vocab_size + 1)
        print(f"logits shape: {logits.shape}")
Permutations: tensor([[0, 1, 2, 3, 4, 5, 6, 7],
[0, 7, 6, 5, 4, 3, 2, 1],
[0, 3, 6, 2, 5, 1, 4, 7],
[0, 4, 1, 5, 2, 6, 3, 7],
[0, 1, 6, 3, 5, 2, 4, 7],
[0, 4, 2, 5, 3, 6, 1, 7]], device='cuda:0', dtype=torch.int32)
Permutation: tensor([0, 1, 2, 3, 4, 5, 6, 7], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded:
tensor([[[[ True, True, True, True, True, True, False]]],
[[[ True, True, True, False, False, False, False]]],
[[[ True, True, True, False, False, False, False]]]],
device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask:
tensor([[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1]]],
[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1]]],
[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask:
tensor([[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0]]],
[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]]],
[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
Permutation: tensor([0, 7, 6, 5, 4, 3, 2, 1], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded:
tensor([[[[ True, True, True, True, True, True, False]]],
[[[ True, True, True, False, False, False, False]]],
[[[ True, True, True, False, False, False, False]]]],
device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask:
tensor([[[[1, 0, 1, 1, 1, 1, 1],
[1, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 1, 1],
[1, 0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]],
[[[1, 0, 1, 1, 1, 1, 1],
[1, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 1, 1],
[1, 0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]],
[[[1, 0, 1, 1, 1, 1, 1],
[1, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 1, 1],
[1, 0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask:
tensor([[[[1, 0, 1, 1, 1, 1, 0],
[1, 0, 0, 1, 1, 1, 0],
[1, 0, 0, 0, 1, 1, 0],
[1, 0, 0, 0, 0, 1, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]],
[[[1, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]],
[[[1, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
Permutation: tensor([0, 3, 6, 2, 5, 1, 4, 7], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded:
tensor([[[[ True, True, True, True, True, True, False]]],
[[[ True, True, True, False, False, False, False]]],
[[[ True, True, True, False, False, False, False]]]],
device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask:
tensor([[[[1, 0, 1, 1, 0, 1, 1],
[1, 0, 0, 1, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 1, 1],
[1, 0, 1, 1, 0, 0, 1],
[1, 0, 0, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1]]],
[[[1, 0, 1, 1, 0, 1, 1],
[1, 0, 0, 1, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 1, 1],
[1, 0, 1, 1, 0, 0, 1],
[1, 0, 0, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1]]],
[[[1, 0, 1, 1, 0, 1, 1],
[1, 0, 0, 1, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 1, 1],
[1, 0, 1, 1, 0, 0, 1],
[1, 0, 0, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask:
tensor([[[[1, 0, 1, 1, 0, 1, 0],
[1, 0, 0, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 1, 0],
[1, 0, 1, 1, 0, 0, 0],
[1, 0, 0, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0]]],
[[[1, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]]],
[[[1, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
Permutation: tensor([0, 4, 1, 5, 2, 6, 3, 7], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded:
tensor([[[[ True, True, True, True, True, True, False]]],
[[[ True, True, True, False, False, False, False]]],
[[[ True, True, True, False, False, False, False]]]],
device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask:
tensor([[[[1, 0, 0, 0, 1, 0, 0],
[1, 1, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 1, 0, 0],
[1, 1, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1]]],
[[[1, 0, 0, 0, 1, 0, 0],
[1, 1, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 1, 0, 0],
[1, 1, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1]]],
[[[1, 0, 0, 0, 1, 0, 0],
[1, 1, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 1, 0, 0],
[1, 1, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask:
tensor([[[[1, 0, 0, 0, 1, 0, 0],
[1, 1, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 1, 0, 0],
[1, 1, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0]]],
[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]]],
[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
Permutation: tensor([0, 1, 6, 3, 5, 2, 4, 7], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded:
tensor([[[[ True, True, True, True, True, True, False]]],
[[[ True, True, True, False, False, False, False]]],
[[[ True, True, True, False, False, False, False]]]],
device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask:
tensor([[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 1, 0, 1, 1],
[1, 1, 0, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 1, 1],
[1, 1, 0, 1, 0, 0, 1],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1]]],
[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 1, 0, 1, 1],
[1, 1, 0, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 1, 1],
[1, 1, 0, 1, 0, 0, 1],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1]]],
[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 1, 0, 1, 1],
[1, 1, 0, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 1, 1],
[1, 1, 0, 1, 0, 0, 1],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask:
tensor([[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 1, 0, 1, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 1, 0],
[1, 1, 0, 1, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0]]],
[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]]],
[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
Permutation: tensor([0, 4, 2, 5, 3, 6, 1, 7], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded:
tensor([[[[ True, True, True, True, True, True, False]]],
[[[ True, True, True, False, False, False, False]]],
[[[ True, True, True, False, False, False, False]]]],
device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask:
tensor([[[[1, 0, 1, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 0, 0],
[1, 0, 1, 0, 1, 1, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 1, 0, 0],
[1, 0, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1]]],
[[[1, 0, 1, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 0, 0],
[1, 0, 1, 0, 1, 1, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 1, 0, 0],
[1, 0, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1]]],
[[[1, 0, 1, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 0, 0],
[1, 0, 1, 0, 1, 1, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 1, 0, 0],
[1, 0, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask:
tensor([[[[1, 0, 1, 1, 1, 1, 0],
[1, 0, 0, 0, 1, 0, 0],
[1, 0, 1, 0, 1, 1, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 1, 0, 0],
[1, 0, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0]]],
[[[1, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]]],
[[[1, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
# One refine iteration
# Update query mask
query_mask[
    torch.triu(
        torch.ones(self.max_length + 1, self.max_length + 1, dtype=torch.bool, device=features.device), 2
    )
] = 1

# Prepare target input for 1 refine iteration
sos = torch.full((features.size(0), 1), self.vocab_size + 1, dtype=torch.long, device=features.device)
ys = torch.cat([sos, logits[:, :-1].argmax(-1)], dim=1)

# Create padding mask for the refined target input: marks everything behind the EOS token as False
# (N, 1, 1, max_length)
target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
The refinement process can be done regardless of the initial decoding scheme (AR or NAR). I suggest moving this to a separate method so it can be used by either AR or NAR decoding.
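A hedged sketch of what pulling the refinement iteration into its own method could look like (variable names mirror the snippet above; this is not the final doctr implementation):

def refine(self, logits, features, pos_queries, query_mask):
    # re-feed the current prediction (prefixed with <sos>) through the decoder once
    sos = torch.full((features.size(0), 1), self.vocab_size + 1, dtype=torch.long, device=features.device)
    ys = torch.cat([sos, logits[:, :-1].argmax(-1)], dim=1)
    # mask every position behind the predicted <eos> token
    target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
    mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
    return self.head(self.decode(ys, features, mask, target_query=pos_queries))

Both the AR and NAR paths could then call refine() on their first-pass logits.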
👍
Hi, I am going to add the PARSeq model to the list of doctr models.
This PR:
Any feedback is welcome :)