Skip to content

Commit

Permalink
ORTModelForFeatureExtraction always exports as transformers models (#…
Browse files Browse the repository at this point in the history
…1684)

fix
  • Loading branch information
fxmarty committed Feb 6, 2024
1 parent da6f9e2 commit 2a789d6
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
54 changes: 54 additions & 0 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,60 @@ def forward(
# converts output to namedtuple for pipelines post-processing
return BaseModelOutput(last_hidden_state=last_hidden_state)

@classmethod
def _export(
cls,
model_id: str,
config: "PretrainedConfig",
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
subfolder: str = "",
local_files_only: bool = False,
trust_remote_code: bool = False,
provider: str = "CPUExecutionProvider",
session_options: Optional[ort.SessionOptions] = None,
provider_options: Optional[Dict[str, Any]] = None,
use_io_binding: Optional[bool] = None,
task: Optional[str] = None,
) -> "ORTModel":
if task is None:
task = cls._auto_model_to_task(cls.auto_model_class)

save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)

# ORTModelForFeatureExtraction works with Transformers type of models, thus even sentence-transformers models are loaded as such.
main_export(
model_name_or_path=model_id,
output=save_dir_path,
task=task,
do_validation=False,
no_post_process=True,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
library_name="transformers",
)

config.save_pretrained(save_dir_path)
maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)

return cls._from_pretrained(
save_dir_path,
config,
use_io_binding=use_io_binding,
model_save_dir=save_dir,
provider=provider,
session_options=session_options,
provider_options=provider_options,
)


MASKED_LM_EXAMPLE = r"""
Example of feature extraction:
Expand Down
4 changes: 2 additions & 2 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ def test_trust_remote_code(self):
class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = [
"albert",
"bart",
"bart",
"bert",
# "big_bird",
# "bigbird_pegasus",
Expand Down Expand Up @@ -1592,7 +1592,7 @@ def test_compare_to_io_binding(self, model_arch):
class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = [
"albert",
"bart",
"bart",
"bert",
# "big_bird",
# "bigbird_pegasus",
Expand Down

0 comments on commit 2a789d6

Please sign in to comment.