Skip to content

Commit

Permalink
Enable ONNX export when PyTorch and TensorFlow installed in the same …
Browse files Browse the repository at this point in the history
…environment (#15625)
  • Loading branch information
lewtun committed Feb 11, 2022
1 parent 6cf06d1 commit 7e4844f
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/transformers/onnx/features.py
Expand Up @@ -303,8 +303,16 @@ def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, T
The instance of the model.
"""
# If PyTorch and TensorFlow are installed in the same environment, we
# load an AutoModel class by default
model_class = FeaturesManager.get_model_class_for_feature(feature)
return model_class.from_pretrained(model)
try:
model = model_class.from_pretrained(model)
# Load TensorFlow weights in an AutoModel instance if PyTorch and
# TensorFlow are installed in the same environment
except OSError:
model = model_class.from_pretrained(model, from_tf=True)
return model

@staticmethod
def check_supported_model_or_raise(
Expand Down

0 comments on commit 7e4844f

Please sign in to comment.