
Avoid overriding model_type in TasksManager #1647

Merged on Feb 6, 2024 (13 commits)
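The gist of the change: rather than mutating `model.config.model_type` so a checkpoint matches a registered exporter architecture, the export path now resolves an export-time type locally and threads `library_name` through the `TasksManager` lookups. A minimal sketch of the resolution order the diffs below introduce (`resolve_export_model_type` is an illustrative helper name, not a function added by this PR):

```python
def resolve_export_model_type(task: str, config) -> str:
    """Illustrative only: mirrors the resolution order in the __main__.py hunks below."""
    if "stable-diffusion" in task:
        return "stable-diffusion"
    # A config may advertise an export-time architecture without its
    # model_type being overridden (the point of this PR).
    if hasattr(config, "export_model_type"):
        return config.export_model_type.replace("_", "-")
    return config.model_type.replace("_", "-")
```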
2 changes: 1 addition & 1 deletion optimum/commands/export/onnx.py
@@ -140,7 +140,7 @@ def parse_args_onnx(parser):
optional_group.add_argument(
"--library-name",
type=str,
choices=["transformers", "diffusers", "timm"],
choices=["transformers", "diffusers", "timm", "sentence_transformers"],
default=None,
help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"),
)
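For reference, the new choice can be exercised programmatically as well as from the CLI. A minimal sketch, assuming `sentence-transformers/all-MiniLM-L6-v2` as an example checkpoint (the checkpoint id and output directory are illustrative):

```python
from optimum.exporters.onnx import main_export

# Sketch: export a Sentence Transformers checkpoint, pinning the library
# explicitly instead of relying on auto-detection.
main_export(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",  # assumed example id
    output="all_minilm_onnx",
    task="feature-extraction",
    library_name="sentence_transformers",
)
```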
35 changes: 25 additions & 10 deletions optimum/exporters/onnx/__main__.py
@@ -68,17 +68,16 @@ def _get_submodels_and_onnx_configs(
custom_onnx_configs: Dict,
custom_architecture: bool,
_variant: str,
+ library_name: str,
int_dtype: str = "int64",
float_dtype: str = "fp32",
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
- library_name: str = "transformers",
model_kwargs: Optional[Dict] = None,
):
- is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
- if is_stable_diffusion:
+ if library_name == "diffusers":
onnx_config = None
models_and_onnx_configs = get_stable_diffusion_models_for_export(
model, int_dtype=int_dtype, float_dtype=float_dtype
@@ -129,7 +128,7 @@ def _get_submodels_and_onnx_configs(
if fn_get_submodels is not None:
submodels_for_export = fn_get_submodels(model)
else:
- if is_stable_diffusion:
+ if library_name == "diffusers":
submodels_for_export = _get_submodels_for_export_stable_diffusion(model)
elif (
model.config.is_encoder_decoder
@@ -373,12 +372,16 @@ def main_export(

if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
custom_architecture = True
- elif task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"):
+ elif task not in TasksManager.get_supported_tasks_for_model_type(
+     model_type, "onnx", library_name=library_name
+ ):
if original_task == "auto":
autodetected_message = " (auto-detected)"
else:
autodetected_message = ""
- model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
+ model_tasks = TasksManager.get_supported_tasks_for_model_type(
+     model_type, exporter="onnx", library_name=library_name
+ )
raise ValueError(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)
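Since the same model type can be registered for several libraries, the supported-task lookup is now disambiguated by registry. A hypothetical probe (the model type is illustrative):

```python
from optimum.exporters.tasks import TasksManager

# Hypothetical probe: the tasks registered for "clip" under the transformers
# registry need not match what another library registers for the same type.
tasks = TasksManager.get_supported_tasks_for_model_type(
    "clip", "onnx", library_name="transformers"
)
print(sorted(tasks))
```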
@@ -422,7 +425,13 @@ def main_export(
"Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
)

model_type = "stable-diffusion" if "stable-diffusion" in task else model.config.model_type.replace("_", "-")
if "stable-diffusion" in task:
[Review comment, Collaborator] What about stable-diffusion-xl, don't we need an extra case for it?
model_type = "stable-diffusion"
elif hasattr(model.config, "export_model_type"):
model_type = model.config.export_model_type.replace("_", "-")
else:
model_type = model.config.model_type.replace("_", "-")

if (
not custom_architecture
and library_name != "diffusers"
@@ -513,14 +522,20 @@ def onnx_export(
else:
float_dtype = "fp32"

model_type = "stable-diffusion" if library_name == "diffusers" else model.config.model_type.replace("_", "-")
if "stable-diffusion" in task:
model_type = "stable-diffusion"
elif hasattr(model.config, "export_model_type"):
model_type = model.config.export_model_type.replace("_", "-")
else:
model_type = model.config.model_type.replace("_", "-")

custom_architecture = library_name == "transformers" and model_type not in TasksManager._SUPPORTED_MODEL_TYPE
task = TasksManager.map_from_synonym(task)

# TODO: support onnx_config.py in the model repo
if custom_architecture and custom_onnx_configs is None:
raise ValueError(
f"Trying to export a {model.config.model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model.config.model_type} to be supported natively in the ONNX export."
f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export."
)

if task is None:
@@ -690,7 +705,7 @@ def onnx_export(
if library_name == "diffusers":
# TODO: fix Can't pickle local object 'get_stable_diffusion_models_for_export.<locals>.<lambda>'
use_subprocess = False
- elif model.config.model_type in UNPICKABLE_ARCHS:
+ elif model_type in UNPICKABLE_ARCHS:
# Pickling is bugged for nn.utils.weight_norm: https://github.com/pytorch/pytorch/issues/102983
# TODO: fix "Cowardly refusing to serialize non-leaf tensor" error for wav2vec2-conformer
use_subprocess = False
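The `model_type` hunks above are what the PR title refers to: a config can now advertise a supported export architecture via an `export_model_type` attribute while `model_type` itself stays untouched. A sketch with a hypothetical custom config (`my-custom-bert` and the attribute value are assumptions for illustration):

```python
from transformers import PretrainedConfig

class MyCustomConfig(PretrainedConfig):
    model_type = "my-custom-bert"  # hypothetical type, unknown to TasksManager

config = MyCustomConfig()
# Point the exporter at a supported architecture without mutating model_type,
# which downstream code may still rely on.
config.export_model_type = "bert"
assert config.model_type == "my-custom-bert"  # left intact
```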
10 changes: 8 additions & 2 deletions optimum/exporters/onnx/config.py
@@ -344,7 +344,10 @@ def __init__(

# Set up the encoder ONNX config.
encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
exporter="onnx", task="feature-extraction", model_type=config.encoder.model_type
exporter="onnx",
task="feature-extraction",
model_type=config.encoder.model_type,
library_name="transformers",
)
self._encoder_onnx_config = encoder_onnx_config_constructor(
config.encoder, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
@@ -353,7 +356,10 @@ def __init__(

# Set up the decoder ONNX config.
decoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
exporter="onnx", task="feature-extraction", model_type=config.decoder.model_type
exporter="onnx",
task="feature-extraction",
model_type=config.decoder.model_type,
library_name="transformers",
)
kwargs = {}
if issubclass(decoder_onnx_config_constructor.func, OnnxConfigWithPast):
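The encoder/decoder sub-lookups above now pin `library_name="transformers"` rather than inferring it. A hypothetical standalone equivalent of the encoder-side lookup (the model type and task are illustrative; sub-configs of an encoder-decoder model are always transformers models, so the registry is fixed):

```python
from optimum.exporters.tasks import TasksManager

# Hypothetical standalone equivalent of the encoder-side lookup above.
constructor = TasksManager.get_exporter_config_constructor(
    exporter="onnx",
    task="feature-extraction",
    model_type="bert",  # assumed encoder type for illustration
    library_name="transformers",
)
```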
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/utils.py
@@ -323,6 +323,7 @@ def get_stable_diffusion_models_for_export(
text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder,
exporter="onnx",
library_name="diffusers",
task="feature-extraction",
)
text_encoder_onnx_config = text_encoder_config_constructor(
@@ -334,6 +335,7 @@ def get_stable_diffusion_models_for_export(
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.unet,
exporter="onnx",
library_name="diffusers",
task="semantic-segmentation",
model_type="unet",
)
@@ -345,6 +347,7 @@ def get_stable_diffusion_models_for_export(
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_encoder,
exporter="onnx",
library_name="diffusers",
task="semantic-segmentation",
model_type="vae-encoder",
)
@@ -356,6 +359,7 @@ def get_stable_diffusion_models_for_export(
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_decoder,
exporter="onnx",
library_name="diffusers",
task="semantic-segmentation",
model_type="vae-decoder",
)
@@ -366,6 +370,7 @@ def get_stable_diffusion_models_for_export(
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder_2,
exporter="onnx",
library_name="diffusers",
[Reply, Collaborator Author] same

task="feature-extraction",
model_type="clip-text-with-projection",
)
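Each stable-diffusion sub-model lookup above now names its registry explicitly. A standalone sketch of the same call pattern (assumes `pipeline` is a loaded `diffusers` `StableDiffusionPipeline`):

```python
from optimum.exporters.tasks import TasksManager

# Assumes `pipeline` is a loaded diffusers StableDiffusionPipeline;
# library_name selects the registry in which "unet" is registered.
constructor = TasksManager.get_exporter_config_constructor(
    model=pipeline.unet,
    exporter="onnx",
    library_name="diffusers",
    task="semantic-segmentation",
    model_type="unet",
)
onnx_config = constructor(pipeline.unet.config)
```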