
[Feat] Add PARSeq model TF and PT #1205

Merged: 107 commits into main, Jun 15, 2023
Conversation

@nikokks (Contributor) commented May 30, 2023

Hi, I am going to add the PARSeq model to the list of doctr models.

This PR:

  • adds PARSeq tensorflow implementation
  • adds PARSeq pytorch implementation
  • adds corresponding tests

Any feedback is welcome :)

@felixdittrich92 added this to the 0.6.1 milestone on May 30, 2023
@felixdittrich92 added labels on May 30, 2023: topic: documentation, module: models, ext: tests, framework: pytorch, topic: text recognition, type: new feature, ext: docs
@felixdittrich92 (Contributor) left a comment

Hi @NicolasPlaye 👋,

thanks a lot for opening the PR and working on this 👍

First a few general comments:

Some suggestions on how we can go further:

To add the PARSeq model you only need to add 2 files:

I would say let's start with this; the other stuff afterwards is only minor changes :)

@felixdittrich92 (Contributor) commented:

@nikokks here is a template with TODOs for parseq/pytorch.py (maybe more helpful to explain what you need to do for the implementation):

# Copyright (C) 2021-2023, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models._utils import IntermediateLayerGetter

from doctr.datasets import VOCABS

from ...classification import vit_s
from ...utils.pytorch import load_pretrained_params
from .base import _PARSeq, _PARSeqPostProcessor

__all__ = ["PARSeq", "parseq"]

default_cfgs: Dict[str, Dict[str, Any]] = {
    "parseq": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 128),
        "vocab": VOCABS["french"],
        "url": None,
    },
}

class PARSeqDecoder(nn.Module):
    """Implements decoder module of the PARSeq model

    Args:
        TODO

    """
    # TODO

class PARSeq(_PARSeq, nn.Module):
    """Implements a PARSeq architecture as described in `"Scene Text Recognition
    with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.

    Args:
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        embedding_units: number of embedding units
        max_length: maximum word length handled by the model
        input_shape: input shape of the image
        exportable: onnx exportable returns only logits
        cfg: dictionary containing information about the model
    """

    def __init__(
        self,
        feature_extractor,
        vocab: str,
        embedding_units: int,
        max_length: int = 25,
        input_shape: Tuple[int, int, int] = (3, 32, 128),
        exportable: bool = False,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.vocab = vocab
        self.exportable = exportable
        self.cfg = cfg
        self.max_length = max_length + 3  # Add 1 step for EOS, 1 for SOS, 1 for PAD

        self.feat_extractor = feature_extractor
        self.decoder = PARSeqDecoder() # TODO
        self.head = nn.Linear(embedding_units, len(self.vocab) + 3)

        self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)

    def forward(
        self,
        x: torch.Tensor,
        target: Optional[List[str]] = None,
        return_model_output: bool = False,
        return_preds: bool = False,
    ) -> Dict[str, Any]:
        features = self.feat_extractor(x)["features"]  # (batch_size, patches_seqlen, d_model)

        if target is not None:
            _gt, _seq_len = self.build_target(target)
            gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
            gt, seq_len = gt.to(x.device), seq_len.to(x.device)

        if self.training and target is None:
            raise ValueError("Need to provide labels during training")

        # TODO

        out: Dict[str, Any] = {}
        if self.exportable:
            out["logits"] = decoded_features
            return out

        if return_model_output:
            out["out_map"] = decoded_features

        if target is None or return_preds:
            # Post-process predictions
            out["preds"] = self.postprocessor(decoded_features)

        if target is not None:
            out["loss"] = self.compute_loss(decoded_features, gt, seq_len)

        return out

    @staticmethod
    def compute_loss(
        model_output: torch.Tensor,
        gt: torch.Tensor,
        seq_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute categorical cross-entropy loss for the model.
        Sequences are masked after the EOS character.

        Args:
            model_output: predicted logits of the model
            gt: the encoded tensor with gt labels
            seq_len: lengths of each gt word inside the batch

        Returns:
            The loss of the model on the batch
        """
        # TODO
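        # One possible starting point (an assumption, mirroring doctr's ViTSTR loss;
        # the actual PARSeq loss would additionally average over the permutations):
        #   input_len = model_output.shape[1]
        #   seq_len = seq_len + 1  # one extra step for the EOS token
        #   cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
        #   mask_2d = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None]
        #   cce[mask_2d] = 0
        #   return (cce.sum(1) / seq_len.to(dtype=model_output.dtype)).mean()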


class PARSeqPostProcessor(_PARSeqPostProcessor):
    """Post processor for PARSeq architecture

    Args:
        vocab: string containing the ordered sequence of supported characters
    """

    def __call__(
        self,
        logits: torch.Tensor,
    ) -> List[Tuple[str, float]]:
        # TODO


def _parseq(
    arch: str,
    pretrained: bool,
    backbone_fn: Callable[[bool], nn.Module],
    layer: str,
    pretrained_backbone: bool = True,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> PARSeq:
    pretrained_backbone = pretrained_backbone and not pretrained

    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])

    kwargs["vocab"] = _cfg["vocab"]
    kwargs["input_shape"] = _cfg["input_shape"]

    # Feature extractor
    feat_extractor = IntermediateLayerGetter(
        backbone_fn(pretrained_backbone, input_shape=_cfg["input_shape"]),  # type: ignore[call-arg]
        {layer: "features"},
    )

    # Build the model
    model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
    # Load pretrained parameters
    if pretrained:
        # The number of classes is not the same as the number of classes in the pretrained model =>
        # remove the last layer weights
        _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    return model


def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
    """PARSeq architecture from
    `"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.

    >>> import torch
    >>> from doctr.models import parseq
    >>> model = parseq(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 32, 128))
    >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset

    Returns:
        text recognition architecture
    """

    return _parseq(
        "parseq",
        pretrained,
        vit_s,
        "1",
        embedding_units=384,
        ignore_keys=["head.weight", "head.bias"],
        **kwargs,
    )

@nikokks (Contributor Author) commented May 30, 2023

Hi again,
where do I put the model config?

default_cfgs: Dict[str, Dict[str, Any]] = {
    "parseq": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "charset_train": "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~",
        "charset_test": "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~",
        "max_label_length": 25,
        "batch_size": 384,
        "lr": 7e-4,
        "warmup_pct": 0.075,
        "weight_decay": 0.0,
        "img_size": [32, 128],
        "patch_size": [4, 8],
        "embed_dim": 384,
        "enc_num_heads": 6,
        "enc_mlp_ratio": 4,
        "enc_depth": 12,
        "dec_num_heads": 12,
        "dec_mlp_ratio": 4,
        "dec_depth": 1,
        "perm_num": 6,
        "perm_forward": True,
        "perm_mirrored": True,
        "decode_ar": True,
        "refine_iters": 1,
        "dropout": 0.1,
        "vocab": VOCABS["french"],
        "input_shape": (3, 32, 128),
        "classes": list(VOCABS["french"]),
        "url": "/home/nikkokks/Desktop/github/parseq-bb5792a6.pt",
    }
}

@felixdittrich92 (Contributor) commented May 30, 2023

> Hi again, where do I put the model config? (config snippet quoted above)

No need to modify the default_cfgs :)
The only values you need are the ones for the decoder part:

"dec_num_heads": 12,
"dec_mlp_ratio": 4 ,
"dec_depth": 1,

You can init the PARSeqDecoder with these values by default, and the same for the decoding:

"perm_num": 6 ,
"perm_forward": True ,
"perm_mirrored": True ,
"decode_ar": True,
"refine_iters": 1,

Then update

class PARSeq(_PARSeq, nn.Module):
    
    def __init__(
        self,
        feature_extractor,
        vocab: str,
        embedding_units: int,
        max_length: int = 25,
        input_shape: Tuple[int, int, int] = (3, 32, 128),
        exportable: bool = False,
        perm_num: int = 6,
        perm_forward: bool = True,
        ....

so we can update the model by passing this config as kwargs :)
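For illustration, the decoder could then expose these values as defaulted arguments, roughly like this (a hypothetical sketch built on torch's stock decoder layer; doctr's actual module differs):

import torch
from torch import nn

class PARSeqDecoder(nn.Module):
    """Hypothetical sketch: a decoder stack with the PARSeq defaults above."""

    def __init__(
        self,
        d_model: int = 384,
        dec_num_heads: int = 12,
        dec_mlp_ratio: int = 4,
        dec_depth: int = 1,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # one decoder layer by default (dec_depth = 1), mirroring the config above
        self.layers = nn.ModuleList(
            nn.TransformerDecoderLayer(d_model, dec_num_heads, d_model * dec_mlp_ratio, dropout, batch_first=True)
            for _ in range(dec_depth)
        )

    def forward(self, target: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            target = layer(target, memory)
        return target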

@nikokks (Contributor Author) commented May 30, 2023

I have modified the file recognition/parseq/pytorch.py.
I did not implement the decoder there because it is in classification/parseq.
To finish recognition/parseq/pytorch.py, what do you think about it?
I do not use the vocab file in the postprocessor but the tokenizer delivered with the model. Is that ok?

@felixdittrich92 (Contributor) commented:

> I have modified the file recognition/parseq/pytorch.py. [...] I do not use the vocab file in the postprocessor but the tokenizer delivered with the model. Is that ok?

We should keep our tokenization; at the end you can copy-paste it from the ViTSTR implementation in doctr :)
We should not touch the classification folder (that's the place for the backbone / encoder models); in the template I have already added the backbone loading (for PARSeq this is vit_s).

So what you need to implement is the decoder in parseq/pytorch.py plus the forward and compute_loss functions.

I would say copy the relevant stuff to parseq/pytorch.py and clean up all these classification additions.
If you need some help to integrate the decoder / forward, I am happy to help :)

@nikokks (Contributor Author) commented May 30, 2023

I have made some changes, like swapping the parseq tokenizer for your vocab.
Do you have any other advice for recognition/parseq?

@nikokks (Contributor Author) commented May 30, 2023

Do you think I should implement the decoder in recognition/parseq or in classification/parseq?
And the forward and compute_loss of PARSeq, do I put them in recognition/parseq or just in classification/parseq?

@felixdittrich92 (Contributor) commented:

@odulcy-mindee @charlesmindee @frgfm the code is ready for review.
@nikokks could you please fix the style and mypy issues? :)

@nikokks (Contributor Author) commented Jun 13, 2023

> @odulcy-mindee @charlesmindee @frgfm the code is ready for review. @nikokks could you please fix the style and mypy issues? :)

What are the commands?

@felixT2K (Contributor) commented:

> What are the commands?

For style: make style
For quality: make quality; this will list the mypy issues, which you need to fix manually.

And in the TensorFlow recognition ONNX test (the last test case in the file), please set up the parseq test the same way as master and sar above, with the min RAM check.
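For reference, such a RAM-gated ONNX test could look roughly like this (a hypothetical sketch, not the actual doctr test file; the threshold is an assumption):

import psutil
import pytest

@pytest.mark.parametrize("arch_name", ["master", "sar_resnet31", "parseq"])
def test_recognition_onnx_export(arch_name):
    # assumed threshold: skip memory-hungry exports on small CI runners
    if psutil.virtual_memory().available < 16 * 1024**3:
        pytest.skip(f"Not enough RAM to export {arch_name} to ONNX")
    # ... the actual export and inference check would go here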

@nikokks (Contributor Author) commented Jun 13, 2023

make quality
isort . -c
Skipped 43 files
ruff check .
black --check .
All done! ✨ 🍰 ✨
246 files would be left unchanged.
mypy doctr/
doctr/__init__.py:3: error: Skipping analyzing "doctr.version": module is installed, but missing library stubs or py.typed marker  [import]
    from .version import __version__  # noqa: F401
    ^
doctr/__init__.py:3: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports
doctr/models/recognition/parseq/tensorflow.py:206: error: No overload variant of "__add__" of "list" matches argument type "int"  [operator]
            combined = tf.concat([sos_idx, perms + 1, eos_idx], axis=1)
                                           ^~~~~~~~~
doctr/models/recognition/parseq/tensorflow.py:206: note: Possible overload variants:
doctr/models/recognition/parseq/tensorflow.py:206: note:     def __add__(self, List[Any], /) -> List[Any]
doctr/models/recognition/parseq/tensorflow.py:206: note:     def [_S] __add__(self, List[_S], /) -> List[Union[_S, Any]]
doctr/models/recognition/parseq/pytorch.py:183: error: Argument 1 to "factorial" has incompatible type "Union[int, float]"; expected "SupportsIndex"  [arg-type]
            max_perms = math.factorial(max_num_chars) // 2
                                       ^~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:192: error: Argument 1 to "range" has incompatible type "Union[int, float]"; expected "SupportsIndex"  [arg-type]
                perm_pool = torch.as_tensor(list(permutations(range(max_num_chars), max_num_chars)), device=seqlen.device)[
                                                                    ^~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:192: error: Argument 2 to "permutations" has incompatible type "Union[int, float]"; expected "Optional[int]"  [arg-type]
                perm_pool = torch.as_tensor(list(permutations(range(max_num_chars), max_num_chars)), device=seqlen.device)[
                                                                                    ^~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:197: error: Incompatible types in assignment (expression has type "Tensor", variable has type "List[Tensor]")  [assignment]
                perms = torch.stack(perms)
                        ^~~~~~~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:200: error: Incompatible types in assignment (expression has type "Tensor", variable has type "List[Tensor]")  [assignment]
                    perms = torch.cat([perms, perm_pool[i]])
                            ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:200: error: List item 0 has incompatible type "List[Tensor]"; expected "Tensor"  [list-item]
                    perms = torch.cat([perms, perm_pool[i]])
                                       ^~~~~
doctr/models/recognition/parseq/pytorch.py:200: error: Invalid index type "ndarray[Any, dtype[signedinteger[_64Bit]]]" for "Tensor"; expected type
"Union[None, int, slice, Tensor, List[Any], Tuple[Any, ...]]"  [index]
                    perms = torch.cat([perms, perm_pool[i]])
                                                        ^
doctr/models/recognition/parseq/pytorch.py:203: error: Argument 1 to "randperm" has incompatible type "Union[int, float]"; expected "int"  [arg-type]
                    [torch.randperm(max_num_chars, device=seqlen.device) for _ in range(num_gen_perms - len(perms))]
                                    ^~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:205: error: Incompatible types in assignment (expression has type "Tensor", variable has type "List[Tensor]")  [assignment]
                perms = torch.stack(perms)
                        ^~~~~~~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:207: error: "List[Tensor]" has no attribute "flip"  [attr-defined]
            comp = perms.flip(-1)
                   ^~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:208: error: Incompatible types in assignment (expression has type "Tensor", variable has type "List[Tensor]")  [assignment]
            perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars)
                    ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:208: error: List item 0 has incompatible type "List[Tensor]"; expected "Tensor"  [list-item]
            perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars)
                                 ^~~~~
doctr/models/recognition/parseq/pytorch.py:208: error: Argument 2 to "reshape" of "_TensorBase" has incompatible type "Union[int, float]"; expected "int"  [arg-type]
            perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars)
                                                                           ^~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:210: error: "List[Tensor]" has no attribute "device"  [attr-defined]
            sos_idx = torch.zeros(len(perms), 1, device=perms.device)
                                                        ^~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:211: error: "List[Tensor]" has no attribute "device"  [attr-defined]
            eos_idx = torch.full((len(perms), 1), max_num_chars + 1, device=perms.device)
                                                                            ^~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:212: error: No overload variant of "__add__" of "list" matches argument type "int"  [operator]
            combined = torch.cat([sos_idx, perms + 1, eos_idx], dim=1).int()
                                           ^~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:212: note: Possible overload variants:
doctr/models/recognition/parseq/pytorch.py:212: note:     def __add__(self, List[Tensor], /) -> List[Tensor]
doctr/models/recognition/parseq/pytorch.py:212: note:     def [_S] __add__(self, List[_S], /) -> List[Union[_S, Tensor]]
doctr/models/recognition/parseq/pytorch.py:289: error: Incompatible types in assignment (expression has type "Tensor", variable has type "List[Any]")  [assignment]
            logits = torch.cat(logits, dim=1)  # (N, max_length, vocab_size + 1)
                     ^~~~~~~~~~~~~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:301: error: "List[Any]" has no attribute "argmax"  [attr-defined]
            ys = torch.cat([sos, logits.argmax(-1)], dim=1)
                                 ^~~~~~~~~~~~~
doctr/models/recognition/parseq/pytorch.py:309: error: Incompatible return value type (got "List[Any]", expected "Tensor")  [return-value]
            return logits  # (N, max_length, vocab_size + 1)
                   ^~~~~~
Found 21 errors in 3 files (checked 154 source files)
make: *** [Makefile:7: quality] Error 1

@felixT2K (Contributor) commented:

@nikokks should be fixed on my branch

@nikokks (Contributor Author) commented Jun 13, 2023

> @nikokks should be fixed on my branch

Ok for me for the last commit on my branch!! =)

@felixdittrich92 changed the title from "DRAFT: adding model PARSeq" to "[Feat] Add PARSeq model TF and PT" on Jun 13, 2023
@felixdittrich92 marked this pull request as ready for review on Jun 13, 2023 at 11:24
codecov bot commented Jun 13, 2023

Codecov Report

Merging #1205 (02bf218) into main (3bd6b3d) will decrease coverage by 1.04%.
The diff coverage is 79.34%.

@@            Coverage Diff             @@
##             main    #1205      +/-   ##
==========================================
- Coverage   94.73%   93.69%   -1.04%     
==========================================
  Files         150      154       +4     
  Lines        6458     6903     +445     
==========================================
+ Hits         6118     6468     +350     
- Misses        340      435      +95     
Flag Coverage Δ
unittests 93.69% <79.34%> (-1.04%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
doctr/models/recognition/parseq/tensorflow.py 76.85% <76.85%> (ø)
doctr/models/recognition/parseq/pytorch.py 78.36% <78.36%> (ø)
doctr/models/modules/transformer/pytorch.py 100.00% <100.00%> (ø)
doctr/models/modules/transformer/tensorflow.py 99.03% <100.00%> (ø)
doctr/models/recognition/__init__.py 100.00% <100.00%> (ø)
doctr/models/recognition/parseq/__init__.py 100.00% <100.00%> (ø)
doctr/models/recognition/parseq/base.py 100.00% <100.00%> (ø)
doctr/models/recognition/vitstr/base.py 100.00% <100.00%> (ø)
doctr/models/recognition/vitstr/pytorch.py 100.00% <100.00%> (ø)
doctr/models/recognition/vitstr/tensorflow.py 97.61% <100.00%> (ø)
... and 1 more

@felixdittrich92 (Contributor) commented:

Thanks @nikokks 👍 now it's fine, let's wait for a final review :)

@odulcy-mindee (Collaborator) commented:

Thank you @nikokks for the PR and @felixdittrich92 for the review, I'll have a look at it today!

@odulcy-mindee (Collaborator) left a comment

Thank you for your contribution! Really great work! The code seems fine to me.
I just have 2 questions, you can merge after that.

@felixdittrich92 a new model to add to the training list haha

Comment on lines +97 to +103
target = target.clone() + self.attention_dropout(
    self.attention(query_norm, content_norm, content_norm, mask=target_mask)
)
target = target.clone() + self.cross_attention_dropout(
    self.cross_attention(self.query_norm(target), memory, memory)
)
target = target.clone() + self.feed_forward_dropout(self.position_feed_forward(self.feed_forward_norm(target)))
Collaborator: Why are there clone calls here?

Contributor:

PyTorch does not allow overwriting in place here because it would raise problems with CUDA :)

The other option would be to use 2 variables, but I personally like the clone() way -> it keeps the code a few lines shorter 😅

@nikokks (Contributor Author) commented Jun 14, 2023

You could do just one clone before line 97 and remove the other clones; I tried it before and it worked (to be verified, to be sure). That should speed up processing.
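For reference, the single-clone variant would look roughly like this (an untested sketch reusing the names from the snippet above):

target = target.clone()
target = target + self.attention_dropout(
    self.attention(query_norm, content_norm, content_norm, mask=target_mask)
)
target = target + self.cross_attention_dropout(
    self.cross_attention(self.query_norm(target), memory, memory)
)
target = target + self.feed_forward_dropout(self.position_feed_forward(self.feed_forward_norm(target)))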

Comment on lines +223 to +231
mask = torch.ones((sz, sz), device=permutation.device)

for i in range(sz):
    query_idx = permutation[i]
    masked_keys = permutation[i + 1 :]
    mask[query_idx, masked_keys] = 0.0
source_mask = mask[:-1, :-1].clone()
mask[torch.eye(sz, dtype=torch.bool, device=permutation.device)] = 0.0
target_mask = mask[1:, :-1]
Collaborator: So we're using 0 and 1 for the mask where the authors used float(-inf) and 0 respectively, am I correct?

Contributor:

Yep, for our MHA implementation 0 is masked and 1 is visible (the transformer decoder can't "see" masked positions; inside the scaled dot product we replace them with -inf, otherwise the model would be able to "cheat"). The softmax activation does the rest :)

To preempt the question of why we don't set the mask directly to -inf:
Short answer: this would raise problems with ONNX export :)
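A minimal sketch of what that convention looks like inside scaled dot-product attention (illustrative only, not doctr's exact code):

import math
import torch

def scaled_dot_product(q, k, v, mask=None):
    # mask convention: 0 = masked (not visible), 1 = visible
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        # replace masked positions with -inf so softmax gives them ~zero weight
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v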

@felixdittrich92 (Contributor) commented:

> @felixdittrich92 a new model to add to the training list haha

🙈 Yes, but I would keep it as the last model on the list to train (I'm still trying to debug some things, but it's really hard; these models with a transformer encoder as backbone need a ton of data 🥲😅)

@felixdittrich92 (Contributor) commented:

I will do some tests; if I am done I will merge it :)

odulcy-mindee previously approved these changes Jun 14, 2023
@odulcy-mindee (Collaborator) left a comment

Okay, thanks for the answers!

@felixdittrich92 merged commit c09cc80 into mindee:main Jun 15, 2023
53 of 57 checks passed
@baudm left a comment

Hello, I'm the original author of PARSeq. @felixdittrich92 asked me to review your implementation. I'll add my comments here even though the PR is already closed.

Overall, it looks correct except for the masking + training loop. Good job and thanks for this initiative. :)

Comment on lines +354 to +356
    logits = self.decode_non_autoregressive(features)
else:
    logits = self.decode_autoregressive(features)
@baudm: Why does target determine whether the inference mode is AR (target is None) or NAR (target is not None)?

Contributor: Hey 👋 we have already changed it in https://github.com/felixdittrich92/doctr/tree/parseq-fixes (that was more for further debugging) :)

# Generate attention masks for the permutations
_, target_mask = self.generate_permutations_attention_masks(perm)
# combine target padding mask and query mask
mask = (target_mask & padding_mask).int()
@baudm commented Jun 16, 2023

target_mask, as generated by generate_permutations_attention_masks(), cannot be ANDed with the padding_mask. The mask generated from a permutation is shared across all sequences (shape: (max_len, max_len)), while the padding_mask varies for each sequence (shape: (N, max_len)). If you want to AND both masks, you have to tile the target_mask for each sequence such that it becomes (N, max_len, max_len), and tile padding_mask for each character output position, i.e. reshape to (N, 1, max_len) first then tile such that it becomes the same shape.

Personally, at least for PyTorch, it would be better to use the padding_mask and target_mask separately since this is handled automatically by the native MHA implementation.
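For illustration, passing the two masks separately to PyTorch's native MHA would look roughly like this (an assumption, not doctr's code; as noted below, doctr uses its own MHA implementation):

import torch
from torch import nn

N, max_len, d_model = 2, 7, 384
query = key = value = torch.rand(N, max_len, d_model)
target_mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))  # True = visible
padding_mask = torch.ones(N, max_len, dtype=torch.bool)  # True = real token

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=12, batch_first=True)
# the native module expects True = "masked out", hence the inversions
out, _ = mha(query, key, value, attn_mask=~target_mask, key_padding_mask=~padding_mask)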

Contributor: @baudm but as you can see we use our own MHA implementation (it is a bit slower than the native implementation, of course), but in the past we have had some trouble with the native implementation, especially with ONNX, and it makes it much easier for us to port it to TensorFlow :)

Contributor:

Pseudocode (not tested):

tiled_target_mask = target_mask.unsqueeze(0).repeat(gt.shape[0], 1, 1)  # (N, max_len, max_len)
# tile the padding mask for each character output position so it matches the shape of tiled_target_mask
tiled_padding_mask = padding_mask.reshape(gt.shape[0], 1, -1).repeat(1, tiled_target_mask.shape[1], 1)  # (N, max_len, max_len)
# combine target padding mask and query mask
mask = (tiled_target_mask & tiled_padding_mask).int()

@nikokks
@baudm correct me if I am wrong :)

Contributor:

@nikokks in this case we need to remove the padding inside the generate_permutations function (see the function return)

Contributor:

@baudm does it have any impact if we pad the permutation list to max_length with the EOS char to ensure uniformly sized attention masks?

Contributor:

Still stuck 😅

Another problem is how we can pad it to self.max_length

# Create padding mask for target input
# [True, True, True, ..., False, False, False] -> False is masked
padding_mask = ~(((gt == self.vocab_size + 2) | (gt == self.vocab_size)).int().cumsum(-1) > 0)

torch.set_printoptions(profile="full")
if self.training:
    # Generate permutations for the target sequences
    tgt_perms = self.generate_permutations(seq_len)
    print(f"Permutations: {tgt_perms}")

    loss = 0
    for perm in tgt_perms:
        print(f"Permutation: {perm}")
        # Generate attention mask for the permutation
        _, target_mask = self.generate_permutations_attention_masks(perm)
        key_padding_mask_expanded = padding_mask[:, :target_mask.shape[-1]].view(features.shape[0], 1, 1, target_mask.shape[-1]).expand(-1, 1, -1, -1)
        print(f"key_padding_mask_expanded shape: {key_padding_mask_expanded.shape}")
        print(f"key_padding_mask_expanded: \n{key_padding_mask_expanded}")
        target_mask = target_mask.view(1, 1, target_mask.shape[-1], target_mask.shape[-1]).expand(features.shape[0], 1, -1, -1)
        print(f"target_mask shape: {target_mask.shape}")
        print(f"target_mask: \n{target_mask}")
        mask = (key_padding_mask_expanded.bool() & target_mask.bool()).int()
        print(f"mask shape: {mask.shape}")
        print(f"mask: \n{mask}")

        logits = self.head(self.decode(gt[:, :target_mask.shape[-1]], features, mask))  # (N, max_length, vocab_size + 1)
        print(f"logits shape: {logits.shape}")
Permutations: tensor([[0, 1, 2, 3, 4, 5, 6, 7],
        [0, 7, 6, 5, 4, 3, 2, 1],
        [0, 3, 6, 2, 5, 1, 4, 7],
        [0, 4, 1, 5, 2, 6, 3, 7],
        [0, 1, 6, 3, 5, 2, 4, 7],
        [0, 4, 2, 5, 3, 6, 1, 7]], device='cuda:0', dtype=torch.int32)
Permutation: tensor([0, 1, 2, 3, 4, 5, 6, 7], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded: 
tensor([[[[ True,  True,  True,  True,  True,  True, False]]],


        [[[ True,  True,  True, False, False, False, False]]],


        [[[ True,  True,  True, False, False, False, False]]]],
       device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask: 
tensor([[[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask: 
tensor([[[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 0]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
Permutation: tensor([0, 7, 6, 5, 4, 3, 2, 1], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded: 
tensor([[[[ True,  True,  True,  True,  True,  True, False]]],


        [[[ True,  True,  True, False, False, False, False]]],


        [[[ True,  True,  True, False, False, False, False]]]],
       device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask: 
tensor([[[[1, 0, 1, 1, 1, 1, 1],
          [1, 0, 0, 1, 1, 1, 1],
          [1, 0, 0, 0, 1, 1, 1],
          [1, 0, 0, 0, 0, 1, 1],
          [1, 0, 0, 0, 0, 0, 1],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0]]],


        [[[1, 0, 1, 1, 1, 1, 1],
          [1, 0, 0, 1, 1, 1, 1],
          [1, 0, 0, 0, 1, 1, 1],
          [1, 0, 0, 0, 0, 1, 1],
          [1, 0, 0, 0, 0, 0, 1],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0]]],


        [[[1, 0, 1, 1, 1, 1, 1],
          [1, 0, 0, 1, 1, 1, 1],
          [1, 0, 0, 0, 1, 1, 1],
          [1, 0, 0, 0, 0, 1, 1],
          [1, 0, 0, 0, 0, 0, 1],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask: 
tensor([[[[1, 0, 1, 1, 1, 1, 0],
          [1, 0, 0, 1, 1, 1, 0],
          [1, 0, 0, 0, 1, 1, 0],
          [1, 0, 0, 0, 0, 1, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0]]],


        [[[1, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0]]],


        [[[1, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
Permutation: tensor([0, 3, 6, 2, 5, 1, 4, 7], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded: 
tensor([[[[ True,  True,  True,  True,  True,  True, False]]],


        [[[ True,  True,  True, False, False, False, False]]],


        [[[ True,  True,  True, False, False, False, False]]]],
       device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask: 
tensor([[[[1, 0, 1, 1, 0, 1, 1],
          [1, 0, 0, 1, 0, 0, 1],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 1, 1],
          [1, 0, 1, 1, 0, 0, 1],
          [1, 0, 0, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 0, 1, 1, 0, 1, 1],
          [1, 0, 0, 1, 0, 0, 1],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 1, 1],
          [1, 0, 1, 1, 0, 0, 1],
          [1, 0, 0, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 0, 1, 1, 0, 1, 1],
          [1, 0, 0, 1, 0, 0, 1],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 1, 1],
          [1, 0, 1, 1, 0, 0, 1],
          [1, 0, 0, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask: 
tensor([[[[1, 0, 1, 1, 0, 1, 0],
          [1, 0, 0, 1, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 1, 0],
          [1, 0, 1, 1, 0, 0, 0],
          [1, 0, 0, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 0]]],


        [[[1, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0]]],


        [[[1, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
Permutation: tensor([0, 4, 1, 5, 2, 6, 3, 7], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded: 
tensor([[[[ True,  True,  True,  True,  True,  True, False]]],


        [[[ True,  True,  True, False, False, False, False]]],


        [[[ True,  True,  True, False, False, False, False]]]],
       device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask: 
tensor([[[[1, 0, 0, 0, 1, 0, 0],
          [1, 1, 0, 0, 1, 1, 0],
          [1, 1, 1, 0, 1, 1, 1],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 1, 0, 0],
          [1, 1, 1, 0, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 0, 0, 0, 1, 0, 0],
          [1, 1, 0, 0, 1, 1, 0],
          [1, 1, 1, 0, 1, 1, 1],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 1, 0, 0],
          [1, 1, 1, 0, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 0, 0, 0, 1, 0, 0],
          [1, 1, 0, 0, 1, 1, 0],
          [1, 1, 1, 0, 1, 1, 1],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 1, 0, 0],
          [1, 1, 1, 0, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask: 
tensor([[[[1, 0, 0, 0, 1, 0, 0],
          [1, 1, 0, 0, 1, 1, 0],
          [1, 1, 1, 0, 1, 1, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 1, 0, 0],
          [1, 1, 1, 0, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 0]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
Permutation: tensor([0, 1, 6, 3, 5, 2, 4, 7], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded: 
tensor([[[[ True,  True,  True,  True,  True,  True, False]]],


        [[[ True,  True,  True, False, False, False, False]]],


        [[[ True,  True,  True, False, False, False, False]]]],
       device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask: 
tensor([[[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 1, 0, 1, 1],
          [1, 1, 0, 0, 0, 0, 1],
          [1, 1, 1, 1, 0, 1, 1],
          [1, 1, 0, 1, 0, 0, 1],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 1, 0, 1, 1],
          [1, 1, 0, 0, 0, 0, 1],
          [1, 1, 1, 1, 0, 1, 1],
          [1, 1, 0, 1, 0, 0, 1],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 1, 0, 1, 1],
          [1, 1, 0, 0, 0, 0, 1],
          [1, 1, 1, 1, 0, 1, 1],
          [1, 1, 0, 1, 0, 0, 1],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask: 
tensor([[[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 1, 0, 1, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 1, 0],
          [1, 1, 0, 1, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 0]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])
Permutation: tensor([0, 4, 2, 5, 3, 6, 1, 7], device='cuda:0', dtype=torch.int32)
key_padding_mask_expanded shape: torch.Size([3, 1, 1, 7])
key_padding_mask_expanded: 
tensor([[[[ True,  True,  True,  True,  True,  True, False]]],


        [[[ True,  True,  True, False, False, False, False]]],


        [[[ True,  True,  True, False, False, False, False]]]],
       device='cuda:0')
target_mask shape: torch.Size([3, 1, 7, 7])
target_mask: 
tensor([[[[1, 0, 1, 1, 1, 1, 1],
          [1, 0, 0, 0, 1, 0, 0],
          [1, 0, 1, 0, 1, 1, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 1, 0, 1, 0, 0],
          [1, 0, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 0, 1, 1, 1, 1, 1],
          [1, 0, 0, 0, 1, 0, 0],
          [1, 0, 1, 0, 1, 1, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 1, 0, 1, 0, 0],
          [1, 0, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 0, 1, 1, 1, 1, 1],
          [1, 0, 0, 0, 1, 0, 0],
          [1, 0, 1, 0, 1, 1, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 1, 0, 1, 0, 0],
          [1, 0, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask shape: torch.Size([3, 1, 7, 7])
mask: 
tensor([[[[1, 0, 1, 1, 1, 1, 0],
          [1, 0, 0, 0, 1, 0, 0],
          [1, 0, 1, 0, 1, 1, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 1, 0, 1, 0, 0],
          [1, 0, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 0]]],


        [[[1, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 1, 0, 0, 0, 0],
          [1, 0, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0]]],


        [[[1, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0],
          [1, 0, 1, 0, 0, 0, 0],
          [1, 0, 1, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0]]]], device='cuda:0', dtype=torch.int32)
logits shape: torch.Size([3, 7, 127])

Comment on lines +291 to +307
# One refine iteration
# Update query mask
query_mask[
    torch.triu(
        torch.ones(self.max_length + 1, self.max_length + 1, dtype=torch.bool, device=features.device), 2
    )
] = 1

# Prepare target input for 1 refine iteration
sos = torch.full((features.size(0), 1), self.vocab_size + 1, dtype=torch.long, device=features.device)
ys = torch.cat([sos, logits[:, :-1].argmax(-1)], dim=1)

# Create padding mask for the refined target input (masks everything after the EOS token as False)
# (N, 1, 1, max_length)
target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
@baudm:

The refinement process can be done regardless of the initial decoding scheme (AR or NAR). I suggest moving this to a separate method so it can be used by either AR or NAR decoding.
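A rough sketch of that refactoring (untested; it just pulls the snippet above into its own method, reusing the same names):

def _refine(self, logits, features, pos_queries, query_mask):
    # one refinement pass over an initial AR or NAR decoding
    sos = torch.full((features.size(0), 1), self.vocab_size + 1, dtype=torch.long, device=features.device)
    ys = torch.cat([sos, logits[:, :-1].argmax(-1)], dim=1)
    # mask everything after the EOS token in the refined target input
    target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
    mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
    return self.head(self.decode(ys, features, mask, target_query=pos_queries))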

Contributor: 👍
