Skip to content

Commit

Permalink
Added TFLite Export for LayoutLMv3
Browse files Browse the repository at this point in the history
  • Loading branch information
salmanmaq committed Apr 26, 2024
1 parent c55f882 commit 43fbefb
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ class TasksManager:
"text-classification",
"token-classification",
onnx="LayoutLMv3OnnxConfig",
tflite="LayoutLMv3TFLiteConfig",
),
"lilt": supported_tasks_mapping(
"feature-extraction",
Expand Down
11 changes: 10 additions & 1 deletion optimum/exporters/tflite/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
configurations.
"""

from ...utils import DummyTextInputGenerator, DummyVisionInputGenerator, logging
from ...utils import DummyBboxInputGenerator, DummyTextInputGenerator, DummyVisionInputGenerator, logging
from .base import TFLiteConfig


Expand All @@ -40,3 +40,12 @@ class VisionTFLiteConfig(TFLiteConfig):

DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)
MANDATORY_AXES = ("batch_size", "num_channels", "width", "height")


class TextAndVisionTFLiteConfig(TFLiteConfig):
"""
Handles multi-modal text and vision architectures.
"""

DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator, DummyBboxInputGenerator)
MANDATORY_AXES = ("batch_size", "sequence_length", "num_channels", "width", "height")
16 changes: 14 additions & 2 deletions optimum/exporters/tflite/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

from typing import List

from ...utils.normalized_config import NormalizedConfigManager
from ...utils.normalized_config import NormalizedConfigManager, NormalizedTextConfig
from .base import QuantizationApproach
from .config import TextEncoderTFliteConfig, VisionTFLiteConfig
from .config import TextAndVisionTFLiteConfig, TextEncoderTFliteConfig, VisionTFLiteConfig


class BertTFLiteConfig(TextEncoderTFliteConfig):
Expand Down Expand Up @@ -124,3 +124,15 @@ class ResNetTFLiteConfig(VisionTFLiteConfig):
@property
def inputs(self) -> List[str]:
return ["pixel_values"]


class LayoutLMv3TFLiteConfig(TextAndVisionTFLiteConfig):
SUPPORTED_QUANTIZATION_APPROACHES = (QuantizationApproach.INT8_DYNAMIC, QuantizationApproach.FP16)

NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
allow_new=True, MAX_2D_POSITION_EMBEDDINGS="max_2d_position_embeddings", image_size="input_size"
)

@property
def inputs(self) -> List[str]:
return ["input_ids", "attention_mask", "bbox", "pixel_values"]

0 comments on commit 43fbefb

Please sign in to comment.