
Add XLNet OnnxConfig #17027

Closed · wants to merge 10 commits

Conversation

@sijunhe (Contributor) commented Apr 30, 2022

What does this PR do?

  1. Add an XLNet OnnxConfig to make this model available for ONNX conversion.
  2. In order to make the ONNX export work, I had to remove the **kwargs argument from the forward function of the XLNet models. The **kwargs argument seems to exist only to support a deprecation warning, and removing it did not break any tests. Below are the reproduction script and the error log from the ONNX export when the **kwargs argument is not removed.
from typing import Mapping, OrderedDict 
from pathlib import Path

from transformers.onnx import OnnxConfig, export
from transformers import AutoTokenizer, AutoModel, AutoConfig

class XLNetOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
                ("token_type_ids", dynamic_axis)
            ]
        )
 
config = AutoConfig.from_pretrained("xlnet-base-cased")
onnx_config = XLNetOnnxConfig(config, task="sequence-classification")

onnx_path = Path("model.onnx")
base_model = AutoModel.from_pretrained("xlnet-base-cased")
tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")

onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [1], in <module>
     28 base_model = AutoModel.from_pretrained("xlnet-base-cased")
     29 tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
---> 31 onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)

File /opt/homebrew/lib/python3.9/site-packages/transformers/onnx/convert.py:116, in export(tokenizer, model, config, opset, output)
    113     config.patch_ops()
    115     # export can works with named args but the dict containing named args as to be last element of the args tuple
--> 116     export(
    117         model,
    118         (model_inputs,),
    119         f=output.as_posix(),
    120         input_names=list(config.inputs.keys()),
    121         output_names=onnx_outputs,
    122         dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
    123         do_constant_folding=True,
    124         use_external_data_format=config.use_external_data_format(model.num_parameters()),
    125         enable_onnx_checker=True,
    126         opset_version=opset,
    127     )
    129     config.restore_ops()
    131 return matched_inputs, onnx_outputs

File /opt/homebrew/lib/python3.9/site-packages/torch/onnx/__init__.py:316, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     38 r"""
     39 Exports a model into ONNX format. If ``model`` is not a
     40 :class:`torch.jit.ScriptModule` nor a :class:`torch.jit.ScriptFunction`, this runs
   (...)
    312     model to the file ``f`` even if this is raised.
    313 """
    315 from torch.onnx import utils
--> 316 return utils.export(model, args, f, export_params, verbose, training,
    317                     input_names, output_names, operator_export_type, opset_version,
    318                     _retain_param_name, do_constant_folding, example_outputs,
    319                     strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
    320                     custom_opsets, enable_onnx_checker, use_external_data_format)

File /opt/homebrew/lib/python3.9/site-packages/torch/onnx/utils.py:107, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    102 if use_external_data_format is not None:
    103     warnings.warn("`use_external_data_format' is deprecated and ignored. Will be removed in next "
    104                   "PyTorch release. The code will work as it is False if models are not larger than 2GB, "
    105                   "Otherwise set to False because of size limits imposed by Protocol Buffers.")
--> 107 _export(model, args, f, export_params, verbose, training, input_names, output_names,
    108         operator_export_type=operator_export_type, opset_version=opset_version,
    109         do_constant_folding=do_constant_folding, example_outputs=example_outputs,
    110         dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
    111         custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)

File /opt/homebrew/lib/python3.9/site-packages/torch/onnx/utils.py:724, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, use_external_data_format, onnx_shape_inference)
    720     dynamic_axes = {}
    721 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
    723 graph, params_dict, torch_out = \
--> 724     _model_to_graph(model, args, verbose, input_names,
    725                     output_names, operator_export_type,
    726                     example_outputs, val_do_constant_folding,
    727                     fixed_batch_size=fixed_batch_size,
    728                     training=training,
    729                     dynamic_axes=dynamic_axes)
    731 # TODO: Don't allocate a in-memory string for the protobuf
    732 defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE

File /opt/homebrew/lib/python3.9/site-packages/torch/onnx/utils.py:493, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
    490 if isinstance(args, (torch.Tensor, int, float, bool)):
    491     args = (args, )
--> 493 graph, params, torch_out, module = _create_jit_graph(model, args)
    495 params_dict = _get_named_param_dict(graph, params)
    497 graph = _optimize_graph(graph, operator_export_type,
    498                         _disable_torch_constant_prop=_disable_torch_constant_prop,
    499                         fixed_batch_size=fixed_batch_size, params_dict=params_dict,
    500                         dynamic_axes=dynamic_axes, input_names=input_names,
    501                         module=module)

File /opt/homebrew/lib/python3.9/site-packages/torch/onnx/utils.py:437, in _create_jit_graph(model, args)
    435     return graph, params, torch_out, None
    436 else:
--> 437     graph, torch_out = _trace_and_get_graph_from_model(model, args)
    438     state_dict = _unique_state_dict(model)
    439     params = list(state_dict.values())

File /opt/homebrew/lib/python3.9/site-packages/torch/onnx/utils.py:388, in _trace_and_get_graph_from_model(model, args)
    381 def _trace_and_get_graph_from_model(model, args):
    382 
    383     # A basic sanity check: make sure the state_dict keys are the same
    384     # before and after running the model.  Fail fast!
    385     orig_state_dict_keys = _unique_state_dict(model).keys()
    387     trace_graph, torch_out, inputs_states = \
--> 388         torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
    389     warn_on_static_input_change(inputs_states)
    391     if orig_state_dict_keys != _unique_state_dict(model).keys():

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:1166, in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
   1164 if not isinstance(args, tuple):
   1165     args = (args,)
-> 1166 outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
   1167 return outs

File /opt/homebrew/lib/python3.9/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:127, in ONNXTracedModule.forward(self, *args)
    124     else:
    125         return tuple(out_vars)
--> 127 graph, out = torch._C._create_graph_by_tracing(
    128     wrapper,
    129     in_vars + module_state,
    130     _create_interpreter_name_lookup_fn(),
    131     self.strict,
    132     self._force_outplace,
    133 )
    135 if self._return_inputs:
    136     return graph, outs[0], ret_inputs[0]

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:118, in ONNXTracedModule.forward.<locals>.wrapper(*args)
    116 if self._return_inputs_states:
    117     inputs_states.append(_unflatten(in_args, in_desc))
--> 118 outs.append(self.inner(*trace_inputs))
    119 if self._return_inputs_states:
    120     inputs_states[0] = (inputs_states[0], trace_inputs)

File /opt/homebrew/lib/python3.9/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/homebrew/lib/python3.9/site-packages/torch/nn/modules/module.py:1090, in Module._slow_forward(self, *input, **kwargs)
   1088         recording_scopes = False
   1089 try:
-> 1090     result = self.forward(*input, **kwargs)
   1091 finally:
   1092     if recording_scopes:

TypeError: forward() takes from 1 to 14 positional arguments but 15 were given

Fixes #16308

Before submitting

Who can review?

@chainyo for the OnnxConfig
@patrickvonplaten and @sgugger for the changes in modeling_xlnet.py

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@chainyo (Contributor) commented May 1, 2022

Hi @sijunhe, nice PR! Could you rebase the branch to avoid pulling all the recent commits into this PR?

@sgugger sgugger requested a review from lewtun May 2, 2022 11:53
@lewtun (Member) commented May 3, 2022

Hi @sijunhe, thanks for this PR! Indeed, as @chainyo suggests, could you please rebase on main so that it is a bit easier to review the changes from your PR?

@sijunhe (Contributor, Author) commented May 3, 2022

Oops! Sorry about that. Merged! @lewtun @chainyo

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sijunhe (Contributor, Author) commented May 13, 2022

Any progress here? @lewtun

@@ -1081,7 +1080,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs, # delete after depreciation warning is removed
Contributor:

I think it will break other usages of this architecture, won't it?

sijunhe (Contributor, Author):

I had to do this because **kwargs breaks the ONNX export, as I mentioned in the PR description. It passed all the unit tests, and the deprecation warning has been up for a while.

Comment on lines 1093 to 1100
if "use_cache" in kwargs:
warnings.warn(
"The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems` instead.",
FutureWarning,
)
use_mems = kwargs["use_cache"]

@chainyo (Contributor) commented May 13, 2022

You should probably keep this for another PR, but it's probably the right time to change use_cache to use_mems.

Contributor:

We should keep this here. It really has nothing to do with this PR, IMO.

Contributor:

Ah sorry, ok I understand! You need to remove the kwargs to make it work for ONNX. Hmm, I sadly don't think we can do this before Transformers v5. @sgugger @LysandreJik what do you think?

Collaborator:

Nope, no breaking change until v5 indeed.

Member:

OK, let's keep this PR open until v5 is released. Nevertheless, thank you for working on this @sijunhe - let's revisit it once we're able to safely remove the kwargs from the forward pass!

Contributor:

That would be awesome, yes!

Member:

Thanks for the feedback @LysandreJik - this is really helpful!

As you suggest, with monkey patching we could have something like the following inside the export_pytorch() (and possibly export_tensorflow()) methods:

model.forward = forward_without_kwargs(model.forward)

where forward_without_kwargs() is a function that wraps the original forward pass to strip out the kwargs.

WDYT @sijunhe - would you like to have a go at implementing this?

Contributor:

Couldn't we otherwise just remove **kwargs, replace it with use_cache=None, and then raise a warning if use_cache is not None?

Contributor:

This would still preserve backward compatibility.
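
For illustration, here is a minimal, self-contained sketch of this pattern (the toy forward below is not the real XLNet signature; it only shows **kwargs being replaced by an explicit, still-deprecated use_cache argument):

import warnings

# Toy illustration: **kwargs is gone, but the deprecated use_cache argument is still
# accepted explicitly, so old call sites keep working and get a FutureWarning.
def forward(input_ids=None, use_mems=None, use_cache=None):
    if use_cache is not None:
        warnings.warn(
            "The `use_cache` argument is deprecated and will be removed in a future version, "
            "use `use_mems` instead.",
            FutureWarning,
        )
        use_mems = use_cache
    return {"input_ids": input_ids, "use_mems": use_mems}

# Callers that still pass use_cache continue to work, with a warning.
forward(input_ids=[1, 2, 3], use_cache=True)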

Collaborator:

It would expose the deprecated arg though. So if there is a solution without removing the kwargs, I'd prefer it.

@sijunhe (Contributor, Author) commented May 21, 2022

Thanks for the review, folks.

I tried what @lewtun suggested about stripping the kwargs, but I couldn't really make it work. model.forward = forward_without_kwargs(model.forward) means forward_without_kwargs would need to change the input signature of model.forward, and I didn't know whether Python can do that. If I try to return a new function based on model.forward, the call becomes an infinite recursion.
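
A hypothetical minimal example of the recursion described here (the Model class and function names are illustrative only):

class Model:
    def forward(self, x):
        return x

def forward_without_kwargs(model):
    def wrapper(*args, **kwargs):
        # model.forward is resolved at call time; after the reassignment below it
        # refers to wrapper itself, so the call recurses forever.
        return model.forward(*args, **kwargs)
    return wrapper

model = Model()
model.forward = forward_without_kwargs(model)
# model.forward(1)  # raises RecursionError

Capturing the original bound forward in the wrapper's closure (as in the snippet @lewtun posts later in the thread) avoids this loop.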

Instead, I took @patrickvonplaten's suggestion and replaced **kwargs with a single use_cache argument.

@patrickvonplaten (Contributor)

> Thanks for the review, folks.
>
> I tried what @lewtun suggested about stripping the kwargs, but I couldn't really make it work. model.forward = forward_without_kwargs(model.forward) means forward_without_kwargs would need to change the input signature of model.forward, and I didn't know whether Python can do that. If I try to return a new function based on model.forward, the call becomes an infinite recursion.
>
> Instead, I took @patrickvonplaten's suggestion and replaced **kwargs with a single use_cache argument.

Since it's an edge case, I'm OK with this! Thanks for making the change @sijunhe - what do you think @LysandreJik @sgugger?
We should add to the docstring that the param is deprecated as well, I guess.

@sgugger (Collaborator) commented May 23, 2022

No, the param is not documented since it's deprecated, and it should stay that way IMO.

@lewtun (Member) commented May 23, 2022

If I'm not mistaken, can't we define a wrapper function to strip out **kwargs from the function signature? This is roughly what I had in mind to handle the forward pass:

from transformers import AutoModel
import inspect
import functools

def forward_without_kwargs(forward):
    @functools.wraps(forward)
    def wrapper(*args, **kwargs):
        return forward(*args, **kwargs)

    # Override the signature and strip out the trailing **kwargs parameter
    sig = inspect.signature(forward)
    sig = sig.replace(parameters=tuple(sig.parameters.values())[:-1])
    wrapper.__signature__ = sig

    return wrapper

# Load an XLNet checkpoint
model = AutoModel.from_pretrained("xlnet-base-cased")
# Has kwargs
inspect.signature(model.forward)
# Has no kwargs
model.forward = forward_without_kwargs(model.forward)
inspect.signature(model.forward)

This function could live in onnx/utils.py and then be called within the export_pytorch() function by checking if kwargs is present in the model's forward signature and stripping it out if so.

Of course, this would also need to be tested properly - just an idea :)
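
As a rough sketch of that check (maybe_strip_kwargs and its placement are assumptions rather than an existing transformers API; it reuses the forward_without_kwargs helper from the snippet above):

import inspect

def maybe_strip_kwargs(model):
    # Only patch the forward pass when its signature actually contains **kwargs.
    params = inspect.signature(model.forward).parameters.values()
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        model.forward = forward_without_kwargs(model.forward)
    return model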

@patrickvonplaten (Contributor)

> If I'm not mistaken, can't we define a wrapper function to strip out **kwargs from the function signature? This is roughly what I had in mind to handle the forward pass:

Also fine with me

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jun 27, 2022
@lewtun lewtun reopened this Sep 30, 2022
@github-actions github-actions bot closed this Oct 8, 2022
Development

Successfully merging this pull request may close these issues.

ONNXConfig: Add a configuration for all available models