Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Layoutlm onnx support (Issue #13300) #13562

Merged
merged 30 commits into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a14d851
Add support for exporting PyTorch LayoutLM to ONNX
nishprabhu Aug 27, 2021
21a845f
Added tests for converting LayoutLM to ONNX
nishprabhu Aug 27, 2021
68bb7c1
Merge branch 'huggingface:master' into layoutlm-onnx-support
nishprabhu Aug 27, 2021
402d542
Add support for exporting PyTorch LayoutLM to ONNX
nishprabhu Aug 27, 2021
f835ae7
Added tests for converting LayoutLM to ONNX
nishprabhu Aug 27, 2021
7a4f00d
cleanup
nishprabhu Aug 27, 2021
0424055
Merge branch 'layoutlm-onnx-support' of https://github.com/nishprabhu…
nishprabhu Aug 27, 2021
9662261
Removed regression/ folder
nishprabhu Aug 27, 2021
056d085
Add support for exporting PyTorch LayoutLM to ONNX
nishprabhu Aug 27, 2021
aedb3d1
Added tests for converting LayoutLM to ONNX
nishprabhu Aug 27, 2021
6153781
cleanup
nishprabhu Aug 27, 2021
cb28af7
Fixed import error
nishprabhu Aug 31, 2021
f815649
Fixed merge conflicts in configuration_layoutlm.py
nishprabhu Aug 31, 2021
f6a78cf
Remove unnecessary import statements
nishprabhu Aug 31, 2021
c99c6f3
Changed max_2d_positions from class variable to instance variable of …
nishprabhu Sep 1, 2021
4214721
Add support for exporting PyTorch LayoutLM to ONNX
nishprabhu Aug 27, 2021
e0a1d86
Added tests for converting LayoutLM to ONNX
nishprabhu Aug 27, 2021
b94f383
cleanup
nishprabhu Aug 27, 2021
36b6024
Add support for exporting PyTorch LayoutLM to ONNX
nishprabhu Aug 27, 2021
9162f5d
cleanup
nishprabhu Aug 27, 2021
9fe974b
Fixed import error
nishprabhu Aug 31, 2021
b5010a9
Changed max_2d_positions from class variable to instance variable of …
nishprabhu Sep 1, 2021
ebf4f4a
Merge branch 'layoutlm-onnx-support' of https://github.com/nishprabhu…
nishprabhu Sep 1, 2021
5993a2d
Merge branch 'huggingface:master' into layoutlm-onnx-support
nishprabhu Sep 14, 2021
6e37eb8
Use super class generate_dummy_inputs method
nishprabhu Sep 15, 2021
e3f33ee
Add support for Masked LM, sequence classification and token classifi…
nishprabhu Sep 15, 2021
84277c9
Removed uncessary import and method
nishprabhu Sep 15, 2021
cdcf3b7
Fixed code styling
nishprabhu Sep 15, 2021
c569d33
Raise error if PyTorch is not installed
nishprabhu Sep 15, 2021
381a040
Remove unnecessary import statement
nishprabhu Sep 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/models/layoutlm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


_import_structure = {
"configuration_layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig"],
"configuration_layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMOnnxConfig"],
"tokenization_layoutlm": ["LayoutLMTokenizer"],
}

Expand Down Expand Up @@ -54,7 +54,7 @@


if TYPE_CHECKING:
from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMOnnxConfig
from .tokenization_layoutlm import LayoutLMTokenizer

if is_tokenizers_available():
Expand Down
70 changes: 70 additions & 0 deletions src/transformers/models/layoutlm/configuration_layoutlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" LayoutLM model configuration """
from collections import OrderedDict
from typing import Any, List, Mapping, Optional

from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType

from ... import is_torch_available
from ...onnx import OnnxConfig, PatchingSpec
from ...utils import logging
from ..bert.configuration_bert import BertConfig

Expand Down Expand Up @@ -125,3 +130,68 @@ def __init__(
**kwargs,
)
self.max_2d_position_embeddings = max_2d_position_embeddings


class LayoutLMOnnxConfig(OnnxConfig):
def __init__(
self,
config: PretrainedConfig,
task: str = "default",
patching_specs: List[PatchingSpec] = None,
):
super().__init__(config, task=task, patching_specs=patching_specs)
self.max_2d_positions = config.max_2d_position_embeddings - 1

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("bbox", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("token_type_ids", {0: "batch", 1: "sequence"}),
]
)

def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
"""
Generate inputs to provide to the ONNX exporter for the specific framework

Args:
tokenizer: The tokenizer associated with this model configuration
batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
framework: The framework (optional) the tokenizer will generate tensor for

Returns:
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
"""

input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

# Generate a dummy bbox
box = [48, 84, 73, 128]

if not framework == TensorType.PYTORCH:
raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")

if not is_torch_available():
raise ValueError("Cannot generate dummy inputs without PyTorch installed.")
import torch
nishprabhu marked this conversation as resolved.
Show resolved Hide resolved

input_dict["bbox"] = torch.tensor(
[
[0] * 4,
*[box] * seq_length,
[self.max_2d_positions] * 4,
]
).tile(batch_size, 1, 1)
return input_dict
8 changes: 8 additions & 0 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..models.distilbert import DistilBertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.layoutlm import LayoutLMOnnxConfig
from ..models.longformer import LongformerOnnxConfig
from ..models.mbart import MBartOnnxConfig
from ..models.roberta import RobertaOnnxConfig
Expand Down Expand Up @@ -78,6 +79,13 @@ class FeaturesManager:
"sequence-classification-with-past",
onnx_config_cls=GPTNeoOnnxConfig,
),
"layoutlm": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"token-classification",
onnx_config_cls=LayoutLMOnnxConfig,
),
}

AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
Expand Down
4 changes: 4 additions & 0 deletions tests/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DistilBertConfig,
GPT2Config,
GPTNeoConfig,
LayoutLMConfig,
MBartConfig,
RobertaConfig,
XLMRobertaConfig,
Expand All @@ -23,6 +24,7 @@
# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.models.layoutlm import LayoutLMOnnxConfig
from transformers.models.mbart import MBartOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig

Expand Down Expand Up @@ -193,6 +195,7 @@ def test_values_override(self):
DistilBertModel,
GPT2Model,
GPTNeoModel,
LayoutLMModel,
MBartModel,
RobertaModel,
XLMRobertaModel,
Expand All @@ -208,6 +211,7 @@ def test_values_override(self):
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig),
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
}
Expand Down