Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/transformers/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def main():
parser.add_argument(
"--atol", type=float, default=None, help="Absolute difference tolerence when validating the model."
)
parser.add_argument(
"--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export."
)
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")

# Retrieve CLI arguments
Expand All @@ -58,7 +61,7 @@ def main():
raise ValueError(f"Unsupported model type: {config.model_type}")

# Allocate the model
model = FeaturesManager.get_model_from_feature(args.feature, args.model)
model = FeaturesManager.get_model_from_feature(args.feature, args.model, framework=args.framework)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
onnx_config = model_onnx_config(model.config)

Expand Down
78 changes: 54 additions & 24 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
elif is_tf_available():
if is_tf_available():
from transformers.models.auto import (
TFAutoModel,
TFAutoModelForCausalLM,
Expand All @@ -48,7 +48,7 @@
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
)
else:
if not is_torch_available() and not is_tf_available():
logger.warning(
"The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
)
Expand Down Expand Up @@ -82,6 +82,8 @@ def supported_features_mapping(


class FeaturesManager:
_TASKS_TO_AUTOMODELS = {}
_TASKS_TO_TF_AUTOMODELS = {}
if is_torch_available():
_TASKS_TO_AUTOMODELS = {
"default": AutoModel,
Expand All @@ -94,8 +96,8 @@ class FeaturesManager:
"question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification,
}
elif is_tf_available():
_TASKS_TO_AUTOMODELS = {
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
"default": TFAutoModel,
"masked-lm": TFAutoModelForMaskedLM,
"causal-lm": TFAutoModelForCausalLM,
Expand All @@ -105,8 +107,6 @@ class FeaturesManager:
"multiple-choice": TFAutoModelForMultipleChoice,
"question-answering": TFAutoModelForQuestionAnswering,
}
else:
_TASKS_TO_AUTOMODELS = {}

# Set of model topologies we support associated to the features supported by each topology and the factory
_SUPPORTED_MODEL_TYPE = {
Expand Down Expand Up @@ -257,11 +257,13 @@ def get_supported_features_for_model_type(
model_type: str, model_name: Optional[str] = None
) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
"""
Try to retrieve the feature -> OnnxConfig constructor map from the model type.
Tries to retrieve the feature -> OnnxConfig constructor map from the model type.

Args:
model_type: The model type to retrieve the supported features for.
model_name: The name attribute of the model object, only used for the exception message.
model_type (`str`):
The model type to retrieve the supported features for.
model_name (`str`, *optional*):
The name attribute of the model object, only used for the exception message.

Returns:
The dictionary mapping each feature to a corresponding OnnxConfig constructor.
Expand All @@ -281,45 +283,73 @@ def feature_to_task(feature: str) -> str:
return feature.replace("-with-past", "")

@staticmethod
def get_model_class_for_feature(feature: str) -> Type:
def _validate_framework_choice(framework: str):
"""
Validates if the framework requested for the export is both correct and available, otherwise throws an
exception.
"""
if framework not in ["pt", "tf"]:
raise ValueError(
f"Only two frameworks are supported for ONNX export: pt or tf, but {framework} was provided."
)
elif framework == "pt" and not is_torch_available():
raise RuntimeError("Cannot export model to ONNX using PyTorch because no PyTorch package was found.")
elif framework == "tf" and not is_tf_available():
raise RuntimeError("Cannot export model to ONNX using TensorFlow because no TensorFlow package was found.")

@staticmethod
def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
"""
Attempt to retrieve an AutoModel class from a feature name.
Attempts to retrieve an AutoModel class from a feature name.

Args:
feature: The feature required.
feature (`str`):
The feature required.
framework (`str`, *optional*, defaults to `"pt"`):
The framework to use for the export.

Returns:
The AutoModel class corresponding to the feature.
"""
task = FeaturesManager.feature_to_task(feature)
if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
FeaturesManager._validate_framework_choice(framework)
if framework == "pt":
task_to_automodel = FeaturesManager._TASKS_TO_AUTOMODELS
else:
task_to_automodel = FeaturesManager._TASKS_TO_TF_AUTOMODELS
if task not in task_to_automodel:
raise KeyError(
f"Unknown task: {feature}. "
f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens here if only tensorflow is installed?

If I understand correctly, since _TASKS_TO_AUTOMODELS is an empty dict by default here, this error won't show the possible values for a pure tensorflow env because we should be accessing _TASKS_TO_TF_AUTOMODELS in that case

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I added a method called _validate_framework_choice that should take care of that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

)
return FeaturesManager._TASKS_TO_AUTOMODELS[task]
return task_to_automodel[task]

def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]:
def get_model_from_feature(
feature: str, model: str, framework: str = "pt"
) -> Union[PreTrainedModel, TFPreTrainedModel]:
"""
Attempt to retrieve a model from a model's name and the feature to be enabled.
Attempts to retrieve a model from a model's name and the feature to be enabled.

Args:
feature: The feature required.
model: The name of the model to export.
feature (`str`):
The feature required.
model (`str`):
The name of the model to export.
framework (`str`, *optional*, defaults to `"pt"`):
The framework to use for the export.

Returns:
The instance of the model.

"""
# If PyTorch and TensorFlow are installed in the same environment, we
# load an AutoModel class by default
model_class = FeaturesManager.get_model_class_for_feature(feature)
model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
try:
model = model_class.from_pretrained(model)
# Load TensorFlow weights in an AutoModel instance if PyTorch and
# TensorFlow are installed in the same environment
except OSError:
model = model_class.from_pretrained(model, from_tf=True)
if framework == "pt":
model = model_class.from_pretrained(model, from_tf=True)
else:
model = model_class.from_pretrained(model, from_pt=True)
return model

@staticmethod
Expand Down