From 12ba701cdf5aa4eb4417ec07052d1319f8318b65 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 9 Mar 2022 15:30:39 +0100 Subject: [PATCH 1/2] Can choose framework for ONNX export --- src/transformers/onnx/__main__.py | 5 ++- src/transformers/onnx/features.py | 55 ++++++++++++++++++++++--------- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py index 6686626ea4bd..22263ff3e08b 100644 --- a/src/transformers/onnx/__main__.py +++ b/src/transformers/onnx/__main__.py @@ -38,6 +38,9 @@ def main(): parser.add_argument( "--atol", type=float, default=None, help="Absolute difference tolerence when validating the model." ) + parser.add_argument( + "--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export." + ) parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.") # Retrieve CLI arguments @@ -58,7 +61,7 @@ def main(): raise ValueError(f"Unsupported model type: {config.model_type}") # Allocate the model - model = FeaturesManager.get_model_from_feature(args.feature, args.model) + model = FeaturesManager.get_model_from_feature(args.feature, args.model, framework=args.framework) model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature) onnx_config = model_onnx_config(model.config) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 41a42d944b75..2b9fefe70ee8 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -37,7 +37,7 @@ AutoModelForSequenceClassification, AutoModelForTokenClassification, ) -elif is_tf_available(): +if is_tf_available(): from transformers.models.auto import ( TFAutoModel, TFAutoModelForCausalLM, @@ -48,7 +48,7 @@ TFAutoModelForSequenceClassification, TFAutoModelForTokenClassification, ) -else: +if not is_torch_available() and not is_tf_available(): logger.warning( "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed." ) @@ -82,6 +82,8 @@ def supported_features_mapping( class FeaturesManager: + _TASKS_TO_AUTOMODELS = {} + _TASKS_TO_TF_AUTOMODELS = {} if is_torch_available(): _TASKS_TO_AUTOMODELS = { "default": AutoModel, @@ -94,8 +96,8 @@ class FeaturesManager: "question-answering": AutoModelForQuestionAnswering, "image-classification": AutoModelForImageClassification, } - elif is_tf_available(): - _TASKS_TO_AUTOMODELS = { + if is_tf_available(): + _TASKS_TO_TF_AUTOMODELS = { "default": TFAutoModel, "masked-lm": TFAutoModelForMaskedLM, "causal-lm": TFAutoModelForCausalLM, @@ -105,8 +107,6 @@ class FeaturesManager: "multiple-choice": TFAutoModelForMultipleChoice, "question-answering": TFAutoModelForQuestionAnswering, } - else: - _TASKS_TO_AUTOMODELS = {} # Set of model topologies we support associated to the features supported by each topology and the factory _SUPPORTED_MODEL_TYPE = { @@ -281,45 +281,68 @@ def feature_to_task(feature: str) -> str: return feature.replace("-with-past", "") @staticmethod - def get_model_class_for_feature(feature: str) -> Type: + def _validate_framework_choice(framework: str): + """ + Validates if the framework requested for the export is both correct and available, otherwise throws an + exception. + """ + if framework not in ["pt", "tf"]: + raise ValueError( + f"Only two frameworks are supported for ONNX export: pt or tf, but {framework} was provided." + ) + elif framework == "pt" and not is_torch_available(): + raise RuntimeError("Cannot export model to ONNX using PyTorch because no PyTorch package was found.") + elif framework == "tf" and not is_tf_available(): + raise RuntimeError("Cannot export model to ONNX using TensorFlow because no TensorFlow package was found.") + + @staticmethod + def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type: """ Attempt to retrieve an AutoModel class from a feature name. Args: feature: The feature required. + framework: The framework to use for the export. Returns: The AutoModel class corresponding to the feature. """ task = FeaturesManager.feature_to_task(feature) - if task not in FeaturesManager._TASKS_TO_AUTOMODELS: + FeaturesManager._validate_framework_choice(framework) + if framework == "pt": + task_to_automodel = FeaturesManager._TASKS_TO_AUTOMODELS + else: + task_to_automodel = FeaturesManager._TASKS_TO_TF_AUTOMODELS + if task not in task_to_automodel: raise KeyError( f"Unknown task: {feature}. " f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}" ) - return FeaturesManager._TASKS_TO_AUTOMODELS[task] + return task_to_automodel[task] - def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]: + def get_model_from_feature( + feature: str, model: str, framework: str = "pt" + ) -> Union[PreTrainedModel, TFPreTrainedModel]: """ Attempt to retrieve a model from a model's name and the feature to be enabled. Args: feature: The feature required. model: The name of the model to export. + framework: The framework to use for the export. Returns: The instance of the model. """ - # If PyTorch and TensorFlow are installed in the same environment, we - # load an AutoModel class by default - model_class = FeaturesManager.get_model_class_for_feature(feature) + model_class = FeaturesManager.get_model_class_for_feature(feature, framework) try: model = model_class.from_pretrained(model) - # Load TensorFlow weights in an AutoModel instance if PyTorch and - # TensorFlow are installed in the same environment except OSError: - model = model_class.from_pretrained(model, from_tf=True) + if framework == "pt": + model = model_class.from_pretrained(model, from_tf=True) + else: + model = model_class.from_pretrained(model, from_pt=True) return model @staticmethod From ec48313aaace93acac46364a1437e3f4f88fb9e0 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 11 Mar 2022 17:04:19 +0100 Subject: [PATCH 2/2] Fix docstring --- src/transformers/onnx/features.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 2b9fefe70ee8..c792342330f8 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -257,11 +257,13 @@ def get_supported_features_for_model_type( model_type: str, model_name: Optional[str] = None ) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]: """ - Try to retrieve the feature -> OnnxConfig constructor map from the model type. + Tries to retrieve the feature -> OnnxConfig constructor map from the model type. Args: - model_type: The model type to retrieve the supported features for. - model_name: The name attribute of the model object, only used for the exception message. + model_type (`str`): + The model type to retrieve the supported features for. + model_name (`str`, *optional*): + The name attribute of the model object, only used for the exception message. Returns: The dictionary mapping each feature to a corresponding OnnxConfig constructor. @@ -298,11 +300,13 @@ def _validate_framework_choice(framework: str): @staticmethod def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type: """ - Attempt to retrieve an AutoModel class from a feature name. + Attempts to retrieve an AutoModel class from a feature name. Args: - feature: The feature required. - framework: The framework to use for the export. + feature (`str`): + The feature required. + framework (`str`, *optional*, defaults to `"pt"`): + The framework to use for the export. Returns: The AutoModel class corresponding to the feature. @@ -324,12 +328,15 @@ def get_model_from_feature( feature: str, model: str, framework: str = "pt" ) -> Union[PreTrainedModel, TFPreTrainedModel]: """ - Attempt to retrieve a model from a model's name and the feature to be enabled. + Attempts to retrieve a model from a model's name and the feature to be enabled. Args: - feature: The feature required. - model: The name of the model to export. - framework: The framework to use for the export. + feature (`str`): + The feature required. + model (`str`): + The name of the model to export. + framework (`str`, *optional*, defaults to `"pt"`): + The framework to use for the export. Returns: The instance of the model.