Skip to content

Commit

Permalink
Fix tflite export (#1767)
Browse files Browse the repository at this point in the history
* Set image-feature-extraction as synonym of feature-extraction

* format
  • Loading branch information
echarlaix authored Mar 22, 2024
1 parent 1146eae commit 66e30ad
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 25 deletions.
1 change: 0 additions & 1 deletion optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ class OnnxConfig(ExportConfig, ABC):
"feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}),
"fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-classification": OrderedDict({"logits": {0: "batch_size"}}),
"image-feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}),
"image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
"image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-to-image": OrderedDict(
Expand Down
28 changes: 4 additions & 24 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ class TasksManager:
"feature-extraction": "AutoModel",
"fill-mask": "AutoModelForMaskedLM",
"image-classification": "AutoModelForImageClassification",
"image-feature-extraction": "AutoModel",
"image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
"image-to-image": "AutoModelForImageToImage",
"image-to-text": "AutoModelForVision2Seq",
Expand Down Expand Up @@ -257,6 +256,7 @@ class TasksManager:
"translation": "text2text-generation",
"vision2seq-lm": "image-to-text",
"zero-shot-classification": "text-classification",
"image-feature-extraction": "feature-extraction",
}

# Reverse dictionaries str -> str, where several model loaders may map to the same task
Expand Down Expand Up @@ -463,13 +463,11 @@ class TasksManager:
),
"convnext": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="ConvNextOnnxConfig",
),
"convnextv2": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="ConvNextV2OnnxConfig",
),
Expand All @@ -486,7 +484,6 @@ class TasksManager:
"data2vec-vision": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-feature-extraction",
# ONNX doesn't support `adaptive_avg_pool2d` yet
# "semantic-segmentation",
onnx="Data2VecVisionOnnxConfig",
Expand Down Expand Up @@ -520,15 +517,13 @@ class TasksManager:
),
"deit": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
"masked-im",
onnx="DeiTOnnxConfig",
),
"detr": supported_tasks_mapping(
"feature-extraction",
"object-detection",
"image-feature-extraction",
"image-segmentation",
onnx="DetrOnnxConfig",
),
Expand All @@ -555,7 +550,6 @@ class TasksManager:
),
"dpt": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"depth-estimation",
"image-segmentation",
"semantic-segmentation",
Expand Down Expand Up @@ -614,7 +608,6 @@ class TasksManager:
),
"glpn": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"depth-estimation",
onnx="GlpnOnnxConfig",
),
Expand Down Expand Up @@ -682,7 +675,6 @@ class TasksManager:
),
"imagegpt": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="ImageGPTOnnxConfig",
),
Expand Down Expand Up @@ -714,9 +706,7 @@ class TasksManager:
"token-classification",
onnx="LiltOnnxConfig",
),
"levit": supported_tasks_mapping(
"feature-extraction", "image-classification", "image-feature-extraction", onnx="LevitOnnxConfig"
),
"levit": supported_tasks_mapping("feature-extraction", "image-classification", onnx="LevitOnnxConfig"),
"longt5": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down Expand Up @@ -780,19 +770,17 @@ class TasksManager:
"mobilevit": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-feature-extraction",
"image-segmentation",
onnx="MobileViTOnnxConfig",
),
"mobilenet-v1": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="MobileNetV1OnnxConfig",
),
"mobilenet-v2": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-feature-extraction",
onnx="MobileNetV2OnnxConfig",
),
"mpnet": supported_tasks_mapping(
Expand Down Expand Up @@ -901,20 +889,17 @@ class TasksManager:
),
"poolformer": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="PoolFormerOnnxConfig",
),
"regnet": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="RegNetOnnxConfig",
),
"resnet": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-feature-extraction",
onnx="ResNetOnnxConfig",
tflite="ResNetTFLiteConfig",
),
Expand Down Expand Up @@ -950,7 +935,6 @@ class TasksManager:
"segformer": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-feature-extraction",
"image-segmentation",
"semantic-segmentation",
onnx="SegformerOnnxConfig",
Expand Down Expand Up @@ -995,14 +979,12 @@ class TasksManager:
),
"swin": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
"masked-im",
onnx="SwinOnnxConfig",
),
"swin2sr": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-to-image",
onnx="Swin2srOnnxConfig",
),
Expand All @@ -1015,7 +997,6 @@ class TasksManager:
),
"table-transformer": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"object-detection",
onnx="TableTransformerOnnxConfig",
),
Expand Down Expand Up @@ -1048,7 +1029,7 @@ class TasksManager:
onnx="VisionEncoderDecoderOnnxConfig",
),
"vit": supported_tasks_mapping(
"feature-extraction", "image-classification", "image-feature-extraction", "masked-im", onnx="ViTOnnxConfig"
"feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig"
),
"wavlm": supported_tasks_mapping(
"feature-extraction",
Expand Down Expand Up @@ -1108,7 +1089,6 @@ class TasksManager:
),
"yolos": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"object-detection",
onnx="YolosOnnxConfig",
),
Expand Down

0 comments on commit 66e30ad

Please sign in to comment.